Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def unalias_group(expression: exp.Expression) -> exp.Expression:
 13    """
 14    Replace references to select aliases in GROUP BY clauses.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 19        'SELECT a AS b FROM x GROUP BY 1'
 20
 21    Args:
 22        expression: the expression that will be transformed.
 23
 24    Returns:
 25        The transformed expression.
 26    """
 27    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 28        aliased_selects = {
 29            e.alias: i
 30            for i, e in enumerate(expression.parent.expressions, start=1)
 31            if isinstance(e, exp.Alias)
 32        }
 33
 34        for group_by in expression.expressions:
 35            if (
 36                isinstance(group_by, exp.Column)
 37                and not group_by.table
 38                and group_by.name in aliased_selects
 39            ):
 40                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 41
 42    return expression
 43
 44
 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 46    """
 47    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 48
 49    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 50
 51    Args:
 52        expression: the expression that will be transformed.
 53
 54    Returns:
 55        The transformed expression.
 56    """
 57    if (
 58        isinstance(expression, exp.Select)
 59        and expression.args.get("distinct")
 60        and expression.args["distinct"].args.get("on")
 61        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 62    ):
 63        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 64        outer_selects = expression.selects
 65        row_number = find_new_name(expression.named_selects, "_row_number")
 66        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 67        order = expression.args.get("order")
 68
 69        if order:
 70            window.set("order", order.pop().copy())
 71        else:
 72            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
 73
 74        window = exp.alias_(window, row_number)
 75        expression.select(window, copy=False)
 76
 77        return (
 78            exp.select(*outer_selects)
 79            .from_(expression.subquery("_t"))
 80            .where(exp.column(row_number).eq(1))
 81        )
 82
 83    return expression
 84
 85
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new SELECT over the original query (wrapped as the subquery `_t`) whose WHERE clause
        holds the former QUALIFY condition, or the input unchanged when it is not a SELECT with
        a QUALIFY clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Projections without an alias or name get a "_c"-based alias, so that the outer query
        # can select every projection by name.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer query selects each inner projection by its alias or name.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        # Detach the QUALIFY node from the query and keep only its condition.
        qualify_filters = expression.args["qualify"].pop().this

        # With a star projection only window functions are looked for in the condition;
        # otherwise the columns it references are candidates for projection as well.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window function under a "_w"-based alias and refer to it through
                # that alias in the outer filter.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The window is the direct child of the QUALIFY node, i.e. the entire
                    # condition, so the local reference is rebound instead of patching the tree.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # A column used by the condition but not projected: add it to the subquery.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
126
127
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform removes the precision from parameterized types found
    anywhere in the given expression.
    """
    for data_type in expression.find_all(exp.DataType):
        # Keep every type argument that is not a DataTypeParam.
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)

        data_type.set("expressions", kept)

    return expression
139
140
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one LATERAL VIEW
    per (unnested expression, alias column) pair. POSEXPLODE is used when the UNNEST has an
    offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (SELECTs are modified in place).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating it makes the iterator skip the next element
        # (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): when the UNNEST has no alias, zip() yields nothing, so the join
                # is dropped without any lateral view replacing it -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
164
165
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Args:
        index_offset: start value of the generated position series that the unnested array
            positions are compared against.

    Returns:
        A transform function that rewrites SELECT expressions in place and returns them; other
        expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            # Imported locally rather than at module level (presumably to avoid an import cycle).
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't clash with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a name via find_new_name and record it as taken.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per explode found; used for the series' upper bound.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset, ...)) source; its "end" argument is only
            # filled in at the bottom, once all array sizes are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # iterate over a copy because the projection list is mutated inside the loop
            for select in expression.selects.copy():
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        # `<explode> AS x`: reuse the user's alias.
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value.
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it in an alias node (named further down)
                        # and look the explode up again inside the replacement.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes an IF that yields the unnested value only when
                    # the series position equals this array's position.
                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also outputs the position: insert it as an extra
                        # projection right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode found: attach the series source once, as a join when
                        # the query already has a FROM clause, otherwise as the FROM itself.
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # For any offset other than 1 the comparisons below use size - 1
                    # (presumably the last position of the array -- TODO confirm for offsets
                    # other than 0).
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position equals this array's position, or
                    # where the series is past `size` and this array's position equals `size`.
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                # Upper bound of the series: GREATEST over all array sizes, adjusted by
                # (1 - index_offset) unless the offset is 1.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
295
296
# Expression types treated as percentile functions by the WITHIN GROUP transforms.
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
298
299
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile function in a WITHIN GROUP clause.

    The percentile's first argument becomes the ORDER BY term of the WITHIN GROUP clause and its
    second argument becomes the percentile's only argument. Percentiles that already sit under a
    WITHIN GROUP node, or that have no second argument, are returned untouched.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    fraction = expression.expression.pop()
    expression.set("this", fraction)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
313
314
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile>(q) WITHIN GROUP (ORDER BY x)` with an ApproxQuantile over `x` and `q`.

    Anything that is not a WITHIN GROUP node wrapping a percentile function and an ORDER BY is
    returned untouched.
    """
    matches = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not matches:
        return expression

    percentile = expression.this
    # The first Ordered node found under the WITHIN GROUP holds the value being ranked.
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx)
327
328
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give the CTEs of a recursive WITH clause explicit column names.

    For each CTE whose alias declares no columns, the names are taken from the projections of
    its query (for a union, from the union's `this` side). Projections without a name receive
    one generated from the "_c_" sequence.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        columns = []
        for projection in query.selects:
            columns.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", columns)

    return expression
346
347
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand 'epoch' (any letter case) of a cast to a temporal type with the literal
    '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
358
359
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join that has an ON condition is removed from the SELECT and replaced by a
    WHERE predicate of the form `EXISTS(SELECT 1 FROM <join target> WHERE <on>)`, negated for
    ANTI joins.
    """
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        predicate = exp.Exists(this=probe)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
375
376
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case a copy of the expression is generated as is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the final expression has neither a
            "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The transforms run on a copy, so the caller's expression is never handed to them.
        # Looping over all transforms (instead of special-casing transforms[0]) also makes an
        # empty list valid, and the loop variable no longer shadows the `typing as t` alias.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method for the resulting expression.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the position of the aliased projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # Alias -> 1-based position of its projection; a later duplicate alias overrides an earlier one.
    positions: t.Dict[str, int] = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for term in expression.expressions:
        # Only unqualified columns can refer to a select alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        name = term.name
        if name in positions:
            term.replace(exp.Literal.number(positions[name]))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a ROW_NUMBER() window inside a subquery, filtered to row 1.

    Intended for dialects that support window functions but not SELECT DISTINCT ON.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach the DISTINCT node from the query; its ON columns become the window partition.
    partition_cols = distinct.pop().args["on"].expressions
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    numbered = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # The query's ORDER BY moves into the window; without one, order by the DISTINCT ON columns.
    order = expression.args.get("order")
    window_order = (
        order.pop().copy() if order else exp.Order(expressions=[c.copy() for c in partition_cols])
    )
    numbered.set("order", window_order)

    expression.select(exp.alias_(numbered, row_number_alias), copy=False)

    outer = exp.select(*projections)
    outer = outer.from_(expression.subquery("_t"))
    return outer.where(exp.column(row_number_alias).eq(1))

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new SELECT over the original query (wrapped as the subquery `_t`) whose WHERE clause
        holds the former QUALIFY condition, or the input unchanged when it is not a SELECT with
        a QUALIFY clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Projections without an alias or name get a "_c"-based alias, so that the outer query
        # can select every projection by name.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer query selects each inner projection by its alias or name.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        # Detach the QUALIFY node from the query and keep only its condition.
        qualify_filters = expression.args["qualify"].pop().this

        # With a star projection only window functions are looked for in the condition;
        # otherwise the columns it references are candidates for projection as well.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window function under a "_w"-based alias and refer to it through
                # that alias in the outer filter.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The window is the direct child of the QUALIFY node, i.e. the entire
                    # condition, so the local reference is rebound instead of patching the tree.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # A column used by the condition but not projected: add it to the subquery.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. This transform removes the precision from parameterized types found
    anywhere in the given expression.
    """
    for data_type in expression.find_all(exp.DataType):
        # Keep every type argument that is not a DataTypeParam.
        kept = []
        for param in data_type.expressions:
            if not isinstance(param, exp.DataTypeParam):
                kept.append(param)

        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one LATERAL VIEW
    per (unnested expression, alias column) pair. POSEXPLODE is used when the UNNEST has an
    offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (SELECTs are modified in place).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list while iterating it makes the iterator skip the next element
        # (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): when the UNNEST has no alias, zip() yields nothing, so the join
                # is dropped without any lateral view replacing it -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Args:
        index_offset: start value of the generated position series that the unnested array
            positions are compared against.

    Returns:
        A transform function that rewrites SELECT expressions in place and returns them; other
        expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            # Imported locally rather than at module level (presumably to avoid an import cycle).
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't clash with them.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a name via find_new_name and record it as taken.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per explode found; used for the series' upper bound.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset, ...)) source; its "end" argument is only
            # filled in at the bottom, once all array sizes are known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # iterate over a copy because the projection list is mutated inside the loop
            for select in expression.selects.copy():
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        # `<explode> AS x`: reuse the user's alias.
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value.
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it in an alias node (named further down)
                        # and look the explode up again inside the replacement.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes an IF that yields the unnested value only when
                    # the series position equals this array's position.
                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also outputs the position: insert it as an extra
                        # projection right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode found: attach the series source once, as a join when
                        # the query already has a FROM clause, otherwise as the FROM itself.
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # For any offset other than 1 the comparisons below use size - 1
                    # (presumably the last position of the array -- TODO confirm for offsets
                    # other than 0).
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position equals this array's position, or
                    # where the series is past `size` and this array's position equals `size`.
                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                # Upper bound of the series: GREATEST over all array sizes, adjusted by
                # (1 - index_offset) unless the offset is 1.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest (used in hive -> presto).

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile call in a WITHIN GROUP clause.

    Only nodes that are instances of `PERCENTILES`, are not already the direct child of a
    `WithinGroup`, and carry a truthy `expression` argument are rewritten. For those, the
    node's `this` argument is moved into an ORDER BY, and the node's `expression` argument
    takes its place as `this`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new `WithinGroup` wrapping the (mutated) percentile node, or the input unchanged.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    # Detach both arguments first: the old `this` becomes the ordering key, while the old
    # `expression` is promoted to be the function's `this` argument.
    ordering_key = expression.this.pop()
    promoted_arg = expression.expression.pop()
    expression.set("this", promoted_arg)

    order_by = exp.Order(expressions=[exp.Ordered(this=ordering_key)])
    return exp.WithinGroup(this=expression, expression=order_by)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Collapse `<percentile>(...) WITHIN GROUP (ORDER BY ...)` into an `ApproxQuantile` call.

    The rewrite applies only to a `WithinGroup` whose `this` is an instance of `PERCENTILES`
    and whose `expression` is an `Order`. The first `Ordered` node found under the
    `WithinGroup` supplies the input value, and the percentile's `this` argument supplies
    the quantile.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `ApproxQuantile` node that replaced the `WithinGroup`, or the input unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not (isinstance(percentile, PERCENTILES) and isinstance(expression.expression, exp.Order)):
        return expression

    # The cast only informs the type checker; `find` is expected to locate the Ordered node
    # that lives under the Order clause checked above.
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already declares columns are left alone. For the others, the column
    names are taken from the output names of the CTE query's projections; when the query is
    a `Union`, the projections of its `this` operand are used. A projection without an
    output name receives one drawn from `name_sequence("_c_")`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with CTE aliases updated in place where needed.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # A single generator is shared by all CTEs of this WITH clause.
    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        projection_source = body.this if isinstance(body, exp.Union) else body

        column_names = [
            exp.to_identifier(projection.alias_or_name or fresh_name())
            for projection in projection_source.selects
        ]
        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite casts of 'epoch' to a temporal type so they cast an explicit timestamp string.

    When the node is a `Cast`/`TryCast` whose name equals "epoch" (case-insensitively) and
    whose target type is one of `exp.DataType.TEMPORAL_TYPES`, the cast operand is replaced
    by the string literal '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with its operand replaced in place when the rule applies.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join of a `Select` that has an ON condition is removed and replaced by a
    WHERE predicate of the form `EXISTS(SELECT 1 FROM <joined source> WHERE <on>)`, negated
    for ANTI joins. Joins without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a `Select` with such joins.
    """
    if not isinstance(expression, exp.Select):
        return expression

    # NOTE(review): joins are detached via `pop()` while this list is being iterated --
    # confirm that `pop()` does not mutate the list object in place.
    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        probe = exp.select("1").from_(join.this).where(condition)
        predicate = exp.Exists(this=probe)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL.

    The SQL is produced by the first of these that is available:

    1. the generator's "<key>_sql" method, where <key> is the `key` of the resulting expression;
    2. the generator's `TRANSFORMS` entry for the resulting expression's type. This is only used
       when that type differs from the type of the original expression; otherwise the entry
       would normally be this very function and calling it would recurse forever. In the
       same-type case, `exp.Func` instances are rendered with `function_fallback_sql` and
       anything else is an error.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is only copied.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: (from the returned function) if no way of generating SQL for the resulting
            expression can be found.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The transforms receive a copy, so the expression given to the generator is untouched.
        # Copying up front (instead of indexing transforms[0]) also supports an empty sequence.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.