Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import expressions as exp
  7from sqlglot.helper import find_new_name
  8
  9if t.TYPE_CHECKING:
 10    from sqlglot.generator import Generator
 11
 12
 13def unalias_group(expression: exp.Expression) -> exp.Expression:
 14    """
 15    Replace references to select aliases in GROUP BY clauses.
 16
 17    Example:
 18        >>> import sqlglot
 19        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 20        'SELECT a AS b FROM x GROUP BY 1'
 21
 22    Args:
 23        expression: the expression that will be transformed.
 24
 25    Returns:
 26        The transformed expression.
 27    """
 28    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 29        aliased_selects = {
 30            e.alias: i
 31            for i, e in enumerate(expression.parent.expressions, start=1)
 32            if isinstance(e, exp.Alias)
 33        }
 34
 35        for group_by in expression.expressions:
 36            if (
 37                isinstance(group_by, exp.Column)
 38                and not group_by.table
 39                and group_by.name in aliased_selects
 40            ):
 41                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
 42
 43    return expression
 44
 45
 46def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
 47    """
 48    Convert SELECT DISTINCT ON statements to a subquery with a window function.
 49
 50    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
 51
 52    Args:
 53        expression: the expression that will be transformed.
 54
 55    Returns:
 56        The transformed expression.
 57    """
 58    if (
 59        isinstance(expression, exp.Select)
 60        and expression.args.get("distinct")
 61        and expression.args["distinct"].args.get("on")
 62        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
 63    ):
 64        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
 65        outer_selects = expression.selects
 66        row_number = find_new_name(expression.named_selects, "_row_number")
 67        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
 68        order = expression.args.get("order")
 69
 70        if order:
 71            window.set("order", order.pop().copy())
 72
 73        window = exp.alias_(window, row_number)
 74        expression.select(window, copy=False)
 75
 76        return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
 77
 78    return expression
 79
 80
 81def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 82    """
 83    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 84
 85    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 86    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 87
 88    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 89    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 90    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 91    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 92    """
 93    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 94        taken = set(expression.named_selects)
 95        for select in expression.selects:
 96            if not select.alias_or_name:
 97                alias = find_new_name(taken, "_c")
 98                select.replace(exp.alias_(select, alias))
 99                taken.add(alias)
100
101        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
102        qualify_filters = expression.args["qualify"].pop().this
103
104        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
105            if isinstance(expr, exp.Window):
106                alias = find_new_name(expression.named_selects, "_w")
107                expression.select(exp.alias_(expr, alias), copy=False)
108                column = exp.column(alias)
109
110                if isinstance(expr.parent, exp.Qualify):
111                    qualify_filters = column
112                else:
113                    expr.replace(column)
114            elif expr.name not in expression.named_selects:
115                expression.select(expr.copy(), copy=False)
116
117        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
118
119    return expression
120
121
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types in
    expressions.

    For every DataType node found, only the parameters that are themselves DataType nodes are
    kept; all other parameters are dropped. The expression is modified in place and returned.
    """
    for data_type in expression.find_all(exp.DataType):
        nested_types = [param for param in data_type.expressions if isinstance(param, exp.DataType)]
        data_type.set("expressions", nested_types)

    return expression
131
132
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed from the join list. For each
    (unnested expression, alias column) pair, a LATERAL VIEW is appended that calls POSEXPLODE
    when the UNNEST has ordinality set and EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live list below, and
        # removing from a list while iterating over it skips the element after each removal
        # (e.g. the second of two consecutive UNNEST joins would be left unconverted).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): an UNNEST without an alias yields no pairs here, so its join is
                # removed without any lateral being added -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
156
157
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each projection of a SELECT that is an EXPLODE or POSEXPLODE call (optionally aliased) is
    replaced by column references, and an aliased UNNEST source is added to the query: as the
    FROM clause when the query has none, otherwise as a CROSS JOIN.
    """
    if not isinstance(expression, exp.Select):
        return expression

    from sqlglot.optimizer.scope import build_scope

    used_select_names = set(expression.named_selects)
    used_source_names = set(build_scope(expression).selected_sources)

    for projection in expression.selects:
        pos_alias = ""
        explode_alias = ""
        call = projection

        # Peel off the alias wrapper, remembering any user-supplied column names.
        if isinstance(projection, exp.Alias):
            explode_alias = projection.alias
            call = projection.this
        elif isinstance(projection, exp.Aliases):
            pos_alias = projection.aliases[0].name
            explode_alias = projection.aliases[1].name
            call = projection.this

        if not isinstance(call, (exp.Explode, exp.Posexplode)):
            continue

        with_position = isinstance(call, exp.Posexplode)

        argument = call.this
        unnest = exp.Unnest(expressions=[argument.copy()], ordinality=with_position)

        # This ensures that we won't use [POS]EXPLODE's argument as a new selection
        if isinstance(argument, exp.Column):
            used_select_names.add(argument.output_name)

        source_alias = find_new_name(used_source_names, "_u")
        used_source_names.add(source_alias)

        # Without a user-supplied alias, generate fresh names for the produced column(s).
        if not explode_alias:
            explode_alias = find_new_name(used_select_names, "col")
            used_select_names.add(explode_alias)

            if with_position:
                pos_alias = find_new_name(used_select_names, "pos")
                used_select_names.add(pos_alias)

        if with_position:
            # The projection is removed and the position/value columns are appended instead.
            column_names = [explode_alias, pos_alias]
            projection.pop()
            expression.select(pos_alias, explode_alias, copy=False)
        else:
            column_names = [explode_alias]
            projection.replace(exp.column(explode_alias))

        unnest = exp.alias_(unnest, source_alias, table=column_names)

        if expression.args.get("from"):
            expression.join(unnest, join_type="CROSS", copy=False)
        else:
            expression.from_(unnest, copy=False)

    return expression
217
218
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Remove table refs from columns in when statements.

    Inside each WHEN clause of a MERGE, any column whose table part is the merge target (or the
    target's alias, when it has one) is replaced by an unqualified column of the same name.
    Other expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    target_alias = target.args.get("alias")

    target_names = {target.this}
    if target_alias:
        target_names.add(target_alias.this)

    def _unqualify(node):
        # Only columns qualified with the target (or its alias) lose their table part.
        if isinstance(node, exp.Column) and node.args.get("table") in target_names:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        when.transform(_unqualify, copy=False)

    return expression
236
237
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace a WithinGroup node that wraps PercentileCont/PercentileDisc with ApproxQuantile.

    The rewrite applies only when the WithinGroup's second operand is an Order node. The
    percentile function's argument becomes the quantile, and the operand of the first Ordered
    node found in the expression becomes the input value. Other expressions are returned as-is.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx)
249
250
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause an explicit column list.

    CTEs whose alias already lists columns are left alone. For the others, the column names are
    taken from the projections of the CTE's query (the left-hand side when the query is a
    union); projections without a name get generated names "_c_0", "_c_1", ... from a counter
    shared by all CTEs of the WITH clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    counter = itertools.count()

    def _placeholder() -> str:
        return f"_c_{next(counter)}"

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this.this if isinstance(cte.this, exp.Union) else cte.this
        column_names = [
            exp.to_identifier(projection.alias_or_name or _placeholder())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression
268
269
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order, each one
            receiving the result of the previous one. An empty sequence is allowed, in which
            case the (copied) expression is dispatched unchanged.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when the transformed expression has no
            "_sql" method and either no `Generator.TRANSFORMS` entry exists for its type, or the
            entry exists but the expression type did not change (which would recurse forever).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The transforms operate on a copy, so the expression passed in by the caller is never
        # mutated. The loop variable is deliberately not called `t`, which is this module's
        # alias for `typing`.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
14def unalias_group(expression: exp.Expression) -> exp.Expression:
15    """
16    Replace references to select aliases in GROUP BY clauses.
17
18    Example:
19        >>> import sqlglot
20        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
21        'SELECT a AS b FROM x GROUP BY 1'
22
23    Args:
24        expression: the expression that will be transformed.
25
26    Returns:
27        The transformed expression.
28    """
29    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
30        aliased_selects = {
31            e.alias: i
32            for i, e in enumerate(expression.parent.expressions, start=1)
33            if isinstance(e, exp.Alias)
34        }
35
36        for group_by in expression.expressions:
37            if (
38                isinstance(group_by, exp.Column)
39                and not group_by.table
40                and group_by.name in aliased_selects
41            ):
42                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
43
44    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
47def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
48    """
49    Convert SELECT DISTINCT ON statements to a subquery with a window function.
50
51    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
52
53    Args:
54        expression: the expression that will be transformed.
55
56    Returns:
57        The transformed expression.
58    """
59    if (
60        isinstance(expression, exp.Select)
61        and expression.args.get("distinct")
62        and expression.args["distinct"].args.get("on")
63        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
64    ):
65        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
66        outer_selects = expression.selects
67        row_number = find_new_name(expression.named_selects, "_row_number")
68        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
69        order = expression.args.get("order")
70
71        if order:
72            window.set("order", order.pop().copy())
73
74        window = exp.alias_(window, row_number)
75        expression.select(window, copy=False)
76
77        return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
78
79    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 82def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
 83    """
 84    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
 85
 86    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
 87    https://docs.snowflake.com/en/sql-reference/constructs/qualify
 88
 89    Some dialects don't support window functions in the WHERE clause, so we need to include them as
 90    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
 91    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
 92    otherwise we won't be able to refer to it in the outer query's WHERE clause.
 93    """
 94    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
 95        taken = set(expression.named_selects)
 96        for select in expression.selects:
 97            if not select.alias_or_name:
 98                alias = find_new_name(taken, "_c")
 99                select.replace(exp.alias_(select, alias))
100                taken.add(alias)
101
102        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
103        qualify_filters = expression.args["qualify"].pop().this
104
105        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
106            if isinstance(expr, exp.Window):
107                alias = find_new_name(expression.named_selects, "_w")
108                expression.select(exp.alias_(expr, alias), copy=False)
109                column = exp.column(alias)
110
111                if isinstance(expr.parent, exp.Qualify):
112                    qualify_filters = column
113                else:
114                    expr.replace(column)
115            elif expr.name not in expression.named_selects:
116                expression.select(expr.copy(), copy=False)
117
118        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
119
120    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
123def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
124    """
125    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
126    other expressions. This transforms removes the precision from parameterized types in expressions.
127    """
128    for node in expression.find_all(exp.DataType):
129        node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
130
131    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
134def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
135    """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
136    if isinstance(expression, exp.Select):
137        for join in expression.args.get("joins") or []:
138            unnest = join.this
139
140            if isinstance(unnest, exp.Unnest):
141                alias = unnest.args.get("alias")
142                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
143
144                expression.args["joins"].remove(join)
145
146                for e, column in zip(unnest.expressions, alias.columns if alias else []):
147                    expression.append(
148                        "laterals",
149                        exp.Lateral(
150                            this=udtf(this=e),
151                            view=True,
152                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
153                        ),
154                    )
155
156    return expression

Convert cross join unnest into lateral view explode (used in presto -> hive).

def explode_to_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
159def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
160    """Convert explode/posexplode into unnest (used in hive -> presto)."""
161    if isinstance(expression, exp.Select):
162        from sqlglot.optimizer.scope import build_scope
163
164        taken_select_names = set(expression.named_selects)
165        taken_source_names = set(build_scope(expression).selected_sources)
166
167        for select in expression.selects:
168            to_replace = select
169
170            pos_alias = ""
171            explode_alias = ""
172
173            if isinstance(select, exp.Alias):
174                explode_alias = select.alias
175                select = select.this
176            elif isinstance(select, exp.Aliases):
177                pos_alias = select.aliases[0].name
178                explode_alias = select.aliases[1].name
179                select = select.this
180
181            if isinstance(select, (exp.Explode, exp.Posexplode)):
182                is_posexplode = isinstance(select, exp.Posexplode)
183
184                explode_arg = select.this
185                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
186
187                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
188                if isinstance(explode_arg, exp.Column):
189                    taken_select_names.add(explode_arg.output_name)
190
191                unnest_source_alias = find_new_name(taken_source_names, "_u")
192                taken_source_names.add(unnest_source_alias)
193
194                if not explode_alias:
195                    explode_alias = find_new_name(taken_select_names, "col")
196                    taken_select_names.add(explode_alias)
197
198                    if is_posexplode:
199                        pos_alias = find_new_name(taken_select_names, "pos")
200                        taken_select_names.add(pos_alias)
201
202                if is_posexplode:
203                    column_names = [explode_alias, pos_alias]
204                    to_replace.pop()
205                    expression.select(pos_alias, explode_alias, copy=False)
206                else:
207                    column_names = [explode_alias]
208                    to_replace.replace(exp.column(explode_alias))
209
210                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
211
212                if not expression.args.get("from"):
213                    expression.from_(unnest, copy=False)
214                else:
215                    expression.join(unnest, join_type="CROSS", copy=False)
216
217    return expression

Convert explode/posexplode into unnest (used in hive -> presto).

def remove_target_from_merge( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
220def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
221    """Remove table refs from columns in when statements."""
222    if isinstance(expression, exp.Merge):
223        alias = expression.this.args.get("alias")
224        targets = {expression.this.this}
225        if alias:
226            targets.add(alias.this)
227
228        for when in expression.expressions:
229            when.transform(
230                lambda node: exp.column(node.name)
231                if isinstance(node, exp.Column) and node.args.get("table") in targets
232                else node,
233                copy=False,
234            )
235
236    return expression

Remove table refs from columns in when statements.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
239def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
240    if (
241        isinstance(expression, exp.WithinGroup)
242        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
243        and isinstance(expression.expression, exp.Order)
244    ):
245        quantile = expression.this.this
246        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
247        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
248
249    return expression
def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
252def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
253    if isinstance(expression, exp.With) and expression.recursive:
254        sequence = itertools.count()
255        next_name = lambda: f"_c_{next(sequence)}"
256
257        for cte in expression.expressions:
258            if not cte.args["alias"].columns:
259                query = cte.this
260                if isinstance(query, exp.Union):
261                    query = query.this
262
263                cte.args["alias"].set(
264                    "columns",
265                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
266                )
267
268    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
271def preprocess(
272    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
273) -> t.Callable[[Generator, exp.Expression], str]:
274    """
275    Creates a new transform by chaining a sequence of transformations and converts the resulting
276    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
277    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
278
279    Args:
280        transforms: sequence of transform functions. These will be called in order.
281
282    Returns:
283        Function that can be used as a generator transform.
284    """
285
286    def _to_sql(self, expression: exp.Expression) -> str:
287        expression_type = type(expression)
288
289        expression = transforms[0](expression.copy())
290        for t in transforms[1:]:
291            expression = t(expression)
292
293        _sql_handler = getattr(self, expression.key + "_sql", None)
294        if _sql_handler:
295            return _sql_handler(expression)
296
297        transforms_handler = self.TRANSFORMS.get(type(expression))
298        if transforms_handler:
299            # Ensures we don't enter an infinite loop. This can happen when the original expression
300            # has the same type as the final expression and there's no _sql method available for it,
301            # because then it'd re-enter _to_sql.
302            if expression_type is type(expression):
303                raise ValueError(
304                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
305                )
306
307            return transforms_handler(self, expression)
308
309        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
310
311    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.