sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY items that name a projection alias into positional references.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(select, exp.Select)):
        return expression

    # Alias name -> 1-based position of the aliased projection in the SELECT list.
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for item in expression.expressions:
        # Only bare (unqualified) column references can be alias references.
        if not isinstance(item, exp.Column) or item.table:
            continue
        position = positions.get(item.name)
        if position is not None:
            item.replace(exp.Literal.number(position))

    return expression


def _outer_reference(select: exp.Expression) -> t.Union[str, exp.Column]:
    """
    Build a reference to `select`'s output name, for use in a query that wraps its SELECT.

    When the name comes from an identifier (an alias, or a column's own name), a new column is
    built from a copy of that identifier, so its quoting is preserved and the name is never
    re-parsed as SQL. Otherwise (e.g. `*`) the bare `alias_or_name` string is returned, to be
    parsed by the query builder.
    """
    identifier = select.args.get("alias") or select.args.get("this")
    if isinstance(identifier, exp.Identifier):
        return exp.Column(this=identifier.copy())
    return select.alias_or_name


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The original query becomes a subquery aliased `_t` that additionally projects
    `ROW_NUMBER() OVER (PARTITION BY <distinct on columns> [ORDER BY <original order>])`, and the
    outer query keeps only the rows whose row number is 1.

    The outer query refers to the subquery's projections by their output names, because the
    subquery does not expose the expressions (or table qualifiers) that produced them. Projections
    without a name are first given a fresh `_c`-based alias.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions

        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                # Alias a copy so `select` stays attached to the SELECT list it is replaced in,
                # regardless of whether `exp.alias_` copies its argument.
                select.replace(exp.alias_(select.copy(), alias))
                taken.add(alias)

        # Collected before the window is projected, so the row number is not part of the output.
        outer_selects = [_outer_reference(select) for select in expression.selects]

        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop().copy())

        expression.select(exp.alias_(window, row_number), copy=False)

        # Build the reference from the bare name, exactly like the alias above, instead of parsing
        # a hard-coded double-quoted identifier: in dialects that treat quoted and unquoted names
        # differently, the two would otherwise not match.
        row_number_filter = exp.EQ(this=exp.column(row_number), expression=exp.Literal.number(1))

        return (
            exp.select(*outer_selects)
            .from_(expression.subquery(alias="_t"))
            .where(row_number_filter)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    The outer query refers to the subquery's projections by name; names that come from identifiers
    keep their quoting.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                # Alias a copy so `select` stays attached to the SELECT list it is replaced in,
                # regardless of whether `exp.alias_` copies its argument.
                select.replace(exp.alias_(select.copy(), alias))
                taken.add(alias)

        outer_selects = exp.select(*[_outer_reference(select) for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
            if isinstance(expr, exp.Window):
                alias = find_new_name(expression.named_selects, "_w")
                # Project a copy so that `expr` itself stays in the filter tree, where it is
                # swapped for a reference to the new projection below.
                expression.select(exp.alias_(expr.copy(), alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The whole filter is the window: replace it wholesale.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Columns used only by QUALIFY must be projected to be visible in the outer WHERE.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every argument of a DataType node that is not itself a DataType is dropped; arguments that are
    DataType nodes (presumably nested types, such as element types) are kept.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    for data_type in expression.find_all(exp.DataType):
        nested_types = [arg for arg in data_type.expressions if isinstance(arg, exp.DataType)]
        data_type.set("expressions", nested_types)

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join whose target is an UNNEST is removed from the SELECT. For each unnested expression
    paired with a column of the UNNEST's alias, a lateral view (`view=True`) of `EXPLODE` is added,
    or of `POSEXPLODE` when the UNNEST has `ordinality` set.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (the same Select, modified in place).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live list below, and
        # removing items from a list while iterating over it skips the element after each
        # removal (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if not isinstance(unnest, exp.Unnest):
                continue

            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

            expression.args["joins"].remove(join)

            # NOTE(review): an UNNEST without an alias is removed without producing any lateral
            # view, since there are no column names to pair its expressions with.
            for e, column in zip(unnest.expressions, alias.columns if alias else []):
                expression.append(
                    "laterals",
                    exp.Lateral(
                        this=udtf(this=e),
                        view=True,
                        alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                    ),
                )

    return expression


def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Every projection of a SELECT that is an `EXPLODE(x)` / `POSEXPLODE(x)` call (optionally
    aliased) is turned into an `UNNEST(x)` source, with its `ordinality` flag set for POSEXPLODE.
    The UNNEST gets a fresh "_u"-based source alias and column names taken from the projection's
    aliases (or freshly generated "col" / "pos"-based ones). It is added as the FROM clause when
    the SELECT has none, or as a CROSS JOIN otherwise, and the projection is swapped for
    reference(s) to the UNNEST's column(s).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (the same Select, modified in place).
    """
    if isinstance(expression, exp.Select):
        # Local import (presumably to avoid an import cycle with the optimizer -- TODO confirm).
        from sqlglot.optimizer.scope import build_scope

        # Names already in use, so that newly generated aliases don't clash with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = set(build_scope(expression).selected_sources)

        for select in expression.selects:
            # The node that gets replaced/removed: the alias wrapper if any, else the call itself.
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                # Single alias, e.g. `EXPLODE(x) AS a`: it names the exploded value.
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                # Alias list, e.g. `POSEXPLODE(x) AS (p, a)`: position first, then value.
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    # NOTE(review): a position alias is only generated when no alias was given
                    # at all; a POSEXPLODE with a single alias leaves `pos_alias` empty.
                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # UNNEST columns are (value, position); the projection is removed and the
                    # position and value columns are selected instead.
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression


def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Remove table refs from columns in when statements.

    In a MERGE, columns inside the WHEN clauses that are qualified with the target table's name
    (or its alias) are replaced by unqualified columns.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (the same Merge, modified in place).
    """
    if isinstance(expression, exp.Merge):
        alias = expression.this.args.get("alias")
        targets = {expression.this.this}
        if alias:
            targets.add(alias.this)

        def _unqualify(node: exp.Expression) -> exp.Expression:
            if isinstance(node, exp.Column) and node.args.get("table") in targets:
                # Reuse a copy of the column's own identifier instead of rebuilding it from its
                # name, so its quoting (and therefore its case sensitivity) is preserved.
                return exp.Column(this=node.this.copy())
            return node

        for when in expression.expressions:
            when.transform(_unqualify, copy=False)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `PERCENTILE_CONT` / `PERCENTILE_DISC` ... `WITHIN GROUP (ORDER BY ...)` with an
    `exp.ApproxQuantile` node.

    The ordered value becomes the ApproxQuantile's input and the percentile function's argument
    becomes its quantile. Any other expression is returned untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement ApproxQuantile node, or `expression` itself if it does not match.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=percentile.this))


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable, see below).

    The transforms are applied to a copy of the expression passed to the returned function. If the
    generator has no "_sql" method for the final expression's key, its `TRANSFORMS` entry for the
    final expression's type is used instead, but only when that type differs from the input
    expression's type; otherwise calling it would re-enter the returned function.

    Args:
        transforms: sequence of transform functions. These will be called in order. If empty,
            SQL is generated for a copy of the input expression as-is.

    Returns:
        Function that can be used as a generator transform. It raises `ValueError` when the final
        expression can be rendered neither by a "_sql" method nor by a usable `TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The loop variable is deliberately not named `t`, which would shadow `typing as t`.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY items that name a projection alias into positional references.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(select, exp.Select)):
        return expression

    # Alias name -> 1-based position of the aliased projection in the SELECT list.
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for item in expression.expressions:
        # Only bare (unqualified) column references can be alias references.
        if not isinstance(item, exp.Column) or item.table:
            continue
        position = positions.get(item.name)
        if position is not None:
            item.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def _outer_reference(select: exp.Expression) -> t.Union[str, exp.Column]:
    """
    Build a reference to `select`'s output name, for use in a query that wraps its SELECT.

    When the name comes from an identifier (an alias, or a column's own name), a new column is
    built from a copy of that identifier, so its quoting is preserved and the name is never
    re-parsed as SQL. Otherwise (e.g. `*`) the bare `alias_or_name` string is returned, to be
    parsed by the query builder.
    """
    identifier = select.args.get("alias") or select.args.get("this")
    if isinstance(identifier, exp.Identifier):
        return exp.Column(this=identifier.copy())
    return select.alias_or_name


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The original query becomes a subquery aliased `_t` that additionally projects
    `ROW_NUMBER() OVER (PARTITION BY <distinct on columns> [ORDER BY <original order>])`, and the
    outer query keeps only the rows whose row number is 1.

    The outer query refers to the subquery's projections by their output names, because the
    subquery does not expose the expressions (or table qualifiers) that produced them. Projections
    without a name are first given a fresh `_c`-based alias.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions

        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                # Alias a copy so `select` stays attached to the SELECT list it is replaced in,
                # regardless of whether `exp.alias_` copies its argument.
                select.replace(exp.alias_(select.copy(), alias))
                taken.add(alias)

        # Collected before the window is projected, so the row number is not part of the output.
        outer_selects = [_outer_reference(select) for select in expression.selects]

        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop().copy())

        expression.select(exp.alias_(window, row_number), copy=False)

        # Build the reference from the bare name, exactly like the alias above, instead of parsing
        # a hard-coded double-quoted identifier: in dialects that treat quoted and unquoted names
        # differently, the two would otherwise not match.
        row_number_filter = exp.EQ(this=exp.column(row_number), expression=exp.Literal.number(1))

        return (
            exp.select(*outer_selects)
            .from_(expression.subquery(alias="_t"))
            .where(row_number_filter)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def _outer_reference(select: exp.Expression) -> t.Union[str, exp.Column]:
    """
    Build a reference to `select`'s output name, for use in a query that wraps its SELECT.

    When the name comes from an identifier (an alias, or a column's own name), a new column is
    built from a copy of that identifier, so its quoting is preserved and the name is never
    re-parsed as SQL. Otherwise (e.g. `*`) the bare `alias_or_name` string is returned, to be
    parsed by the query builder.
    """
    identifier = select.args.get("alias") or select.args.get("this")
    if isinstance(identifier, exp.Identifier):
        return exp.Column(this=identifier.copy())
    return select.alias_or_name


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    The outer query refers to the subquery's projections by name; names that come from identifiers
    keep their quoting.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                # Alias a copy so `select` stays attached to the SELECT list it is replaced in,
                # regardless of whether `exp.alias_` copies its argument.
                select.replace(exp.alias_(select.copy(), alias))
                taken.add(alias)

        outer_selects = exp.select(*[_outer_reference(select) for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
            if isinstance(expr, exp.Window):
                alias = find_new_name(expression.named_selects, "_w")
                # Project a copy so that `expr` itself stays in the filter tree, where it is
                # swapped for a reference to the new projection below.
                expression.select(exp.alias_(expr.copy(), alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The whole filter is the window: replace it wholesale.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Columns used only by QUALIFY must be projected to be visible in the outer WHERE.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every argument of a DataType node that is not itself a DataType is dropped; arguments that are
    DataType nodes (presumably nested types, such as element types) are kept.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    for data_type in expression.find_all(exp.DataType):
        nested_types = [arg for arg in data_type.expressions if isinstance(arg, exp.DataType)]
        data_type.set("expressions", nested_types)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join whose target is an UNNEST is removed from the SELECT. For each unnested expression
    paired with a column of the UNNEST's alias, a lateral view (`view=True`) of `EXPLODE` is added,
    or of `POSEXPLODE` when the UNNEST has `ordinality` set.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (the same Select, modified in place).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the live list below, and
        # removing items from a list while iterating over it skips the element after each
        # removal (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if not isinstance(unnest, exp.Unnest):
                continue

            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

            expression.args["joins"].remove(join)

            # NOTE(review): an UNNEST without an alias is removed without producing any lateral
            # view, since there are no column names to pair its expressions with.
            for e, column in zip(unnest.expressions, alias.columns if alias else []):
                expression.append(
                    "laterals",
                    exp.Lateral(
                        this=udtf(this=e),
                        view=True,
                        alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                    ),
                )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Every projection of a SELECT that is an `EXPLODE(x)` / `POSEXPLODE(x)` call (optionally
    aliased) is turned into an `UNNEST(x)` source, with its `ordinality` flag set for POSEXPLODE.
    The UNNEST gets a fresh "_u"-based source alias and column names taken from the projection's
    aliases (or freshly generated "col" / "pos"-based ones). It is added as the FROM clause when
    the SELECT has none, or as a CROSS JOIN otherwise, and the projection is swapped for
    reference(s) to the UNNEST's column(s).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (the same Select, modified in place).
    """
    if isinstance(expression, exp.Select):
        # Local import (presumably to avoid an import cycle with the optimizer -- TODO confirm).
        from sqlglot.optimizer.scope import build_scope

        # Names already in use, so that newly generated aliases don't clash with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = set(build_scope(expression).selected_sources)

        for select in expression.selects:
            # The node that gets replaced/removed: the alias wrapper if any, else the call itself.
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                # Single alias, e.g. `EXPLODE(x) AS a`: it names the exploded value.
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                # Alias list, e.g. `POSEXPLODE(x) AS (p, a)`: position first, then value.
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    # NOTE(review): a position alias is only generated when no alias was given
                    # at all; a POSEXPLODE with a single alias leaves `pos_alias` empty.
                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    # UNNEST columns are (value, position); the projection is removed and the
                    # position and value columns are selected instead.
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
Convert explode/posexplode into unnest (used in hive -> presto).
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Remove table refs from columns in when statements.

    In a MERGE, columns inside the WHEN clauses that are qualified with the target table's name
    (or its alias) are replaced by unqualified columns.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (the same Merge, modified in place).
    """
    if isinstance(expression, exp.Merge):
        alias = expression.this.args.get("alias")
        targets = {expression.this.this}
        if alias:
            targets.add(alias.this)

        def _unqualify(node: exp.Expression) -> exp.Expression:
            if isinstance(node, exp.Column) and node.args.get("table") in targets:
                # Reuse a copy of the column's own identifier instead of rebuilding it from its
                # name, so its quoting (and therefore its case sensitivity) is preserved.
                return exp.Column(this=node.this.copy())
            return node

        for when in expression.expressions:
            when.transform(_unqualify, copy=False)

    return expression
Remove references to the MERGE target table (or its alias) from the columns used in the WHEN clauses.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `PERCENTILE_CONT` / `PERCENTILE_DISC` ... `WITHIN GROUP (ORDER BY ...)` with an
    `exp.ApproxQuantile` node.

    The ordered value becomes the ApproxQuantile's input and the percentile function's argument
    becomes its quantile. Any other expression is returned untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement ApproxQuantile node, or `expression` itself if it does not match.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=percentile.this))
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable, see below).

    The transforms are applied to a copy of the expression passed to the returned function. If the
    generator has no "_sql" method for the final expression's key, its `TRANSFORMS` entry for the
    final expression's type is used instead, but only when that type differs from the input
    expression's type; otherwise calling it would re-enter the returned function.

    Args:
        transforms: sequence of transform functions. These will be called in order. If empty,
            SQL is generated for a copy of the input expression as-is.

    Returns:
        Function that can be used as a generator transform. It raises `ValueError` when the final
        expression can be rendered neither by a "_sql" method nor by a usable `TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The loop variable is deliberately not named `t`, which would shadow `typing as t`.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
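or the appropriate Generator.TRANSFORMS
function (when there is no "_sql" method). The TRANSFORMS function is only used if the transformed expression's type differs from the original expression's type; otherwise a ValueError is raised, since calling it would re-enter this same transform.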
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.