sqlglot.transforms
from __future__ import annotations

import itertools
import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Map each aliased projection to its 1-based position in the SELECT list.
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            # Only unqualified column references can refer to a select alias.
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects[group_by.name]))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON columns become the PARTITION BY of a ROW_NUMBER() window (ordered by the
    query's ORDER BY, if any), and an outer query over the subquery keeps only the rows whose
    row number is 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions

        # The outer query reads from the subquery, which only exposes its projections by
        # name. An aliased projection such as `a + 1 AS b` must therefore be re-selected as
        # `b`; re-selecting `a + 1 AS b` would reference `a`, which the subquery doesn't
        # expose. The alias identifier is copied so its quoting is preserved.
        outer_selects = [
            exp.Column(this=select.args["alias"].copy())
            if isinstance(select, exp.Alias)
            else select
            for select in expression.selects
        ]
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The ORDER BY decides which row of each partition gets number 1. It is moved
            # into the window, so the resulting outer query carries no ORDER BY.
            window.set("order", order.pop().copy())

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # The derived table is aliased because some dialects (e.g. MySQL) reject unaliased
        # derived tables. The row-number name is double-quoted so it parses as an identifier.
        return (
            exp.select(*outer_selects)
            .from_(expression.subquery(alias="_t"))
            .where(f'"{row_number}" = 1')
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Every projection needs a name so the outer query can re-select it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Build the outer projection list before extra helper projections are appended below.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
            if isinstance(expr, exp.Window):
                # Project the window under a fresh alias; the filter refers to it by name.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The window is the entire QUALIFY condition.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # NOTE(review): a table-qualified column (e.g. `t.a`) keeps its qualifier in the
                # outer WHERE, where only `_t` is in scope -- confirm this is handled upstream.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    Type parameters that are themselves data types are kept.
    """
    for node in expression.find_all(exp.DataType):
        node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Each UNNEST join is removed and replaced by one LATERAL VIEW per unnested expression,
    using POSEXPLODE when the UNNEST has ordinality and EXPLODE otherwise.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the underlying list below,
        # and removing from the list being iterated would skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): zip() pairs expressions with alias columns, so an UNNEST without
                # a table alias (or with fewer columns than expressions) emits fewer LATERAL
                # VIEWs than it has expressions -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """Convert explode/posexplode into unnest (used in hive -> presto)."""
    if isinstance(expression, exp.Select):
        from sqlglot.optimizer.scope import build_scope

        # Names already in use, so that generated aliases don't collide with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = set(build_scope(expression).selected_sources)

        for select in expression.selects:
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                # Two aliases: the first names the position, the second the value.
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    column_names = [explode_alias, pos_alias]
                    # NOTE(review): pop() removes an element from the list being iterated,
                    # which may skip the following projection -- confirm.
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                # The UNNEST becomes the FROM source, or is cross-joined onto an existing one.
                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression


def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """Remove references to the MERGE target table (or its alias) from columns in WHEN clauses."""
    if isinstance(expression, exp.Merge):
        alias = expression.this.args.get("alias")
        targets = {expression.this.this}
        if alias:
            targets.add(alias.this)

        for when in expression.expressions:
            when.transform(
                lambda node: exp.column(node.name)
                if isinstance(node, exp.Column) and node.args.get("table") in targets
                else node,
                copy=False,
            )

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite PERCENTILE_CONT/PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) as an
    ApproxQuantile(x, q) expression.
    """
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give each CTE of a WITH RECURSIVE clause that lacks a column list one derived from its
    query's projections (from the left-hand side of a UNION). Unnamed projections are named
    `_c_0`, `_c_1`, ..., numbered across all CTEs of the clause.
    """
    if isinstance(expression, exp.With) and expression.recursive:
        sequence = itertools.count()
        next_name = lambda: f"_c_{next(sequence)}"

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.Union):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (used only when the generator has no
    "_sql" method for the resulting expression).

    Args:
        transforms: sequence of transform functions. These will be called in order, on a copy
            of the input expression. An empty sequence just generates SQL for the copy.

    Returns:
        Function that can be used as a generator transform. It raises ValueError when the
        result can't be converted: either no handler exists for its type, or the only handler
        is a `TRANSFORMS` entry for the input's own type (which would recurse forever).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Transforms may mutate the tree in place, so they run on a copy of the caller's tree.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY references to SELECT aliases as positional references.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the node being visited; only a GROUP BY directly under a SELECT is changed.

    Returns:
        The (possibly modified in place) node.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # 1-based position of every aliased projection; a repeated alias keeps its last position.
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        # Table-qualified columns can never refer to a select alias.
        if not isinstance(key, exp.Column) or key.table:
            continue
        position = position_by_alias.get(key.name)
        if position is not None:
            key.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON columns become the PARTITION BY of a ROW_NUMBER() window (ordered by the
    query's ORDER BY, if any), and an outer query over the subquery keeps only the rows whose
    row number is 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions

        # The outer query reads from the subquery, which only exposes its projections by
        # name. An aliased projection such as `a + 1 AS b` must therefore be re-selected as
        # `b`; re-selecting `a + 1 AS b` would reference `a`, which the subquery doesn't
        # expose. The alias identifier is copied so its quoting is preserved.
        outer_selects = [
            exp.Column(this=select.args["alias"].copy())
            if isinstance(select, exp.Alias)
            else select
            for select in expression.selects
        ]
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The ORDER BY decides which row of each partition gets number 1. It is moved
            # into the window, so the resulting outer query carries no ORDER BY.
            window.set("order", order.pop().copy())

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # The derived table is aliased because some dialects (e.g. MySQL) reject unaliased
        # derived tables. The row-number name is double-quoted so it parses as an identifier.
        return (
            exp.select(*outer_selects)
            .from_(expression.subquery(alias="_t"))
            .where(f'"{row_number}" = 1')
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed; only a SELECT with QUALIFY changes.

    Returns:
        A new outer SELECT over the rewritten query aliased as `_t`, or the input unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Every projection needs a name so the outer query can re-select it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Build the outer projection list before extra helper projections are appended below.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        # Detach the QUALIFY clause; its condition becomes the outer WHERE.
        qualify_filters = expression.args["qualify"].pop().this

        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
            if isinstance(expr, exp.Window):
                # Project the window under a fresh alias; the filter refers to it by name.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                if isinstance(expr.parent, exp.Qualify):
                    # The window is the entire QUALIFY condition.
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # Columns used by QUALIFY but not projected must be projected for the outer WHERE.
                # NOTE(review): a table-qualified column (e.g. `t.a`) keeps its qualifier in the
                # outer WHERE, where only `_t` is in scope -- confirm this is handled upstream.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip non-type parameters (e.g. precision) from every data type in the tree.

    Some dialects only accept such parameters in DDL, not in other expressions. Parameters that
    are themselves data types are kept. The tree is modified in place and returned.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_params = [param for param in data_type.expressions if isinstance(param, exp.DataType)]
        data_type.set("expressions", kept_params)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions; type parameters that are themselves data types are kept.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Each UNNEST join is removed and replaced by one LATERAL VIEW per unnested expression,
    using POSEXPLODE when the UNNEST has ordinality and EXPLODE otherwise.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: matching joins are removed from the underlying list below,
        # and removing from the list being iterated would skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): zip() pairs expressions with alias columns, so an UNNEST without
                # a table alias (or with fewer columns than expressions) emits fewer LATERAL
                # VIEWs than it has expressions -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each [POS]EXPLODE projection of a SELECT is replaced by reference(s) to the columns of a new
    UNNEST source. That source becomes the FROM clause, or is CROSS JOINed when a FROM exists.
    """
    if isinstance(expression, exp.Select):
        from sqlglot.optimizer.scope import build_scope

        # Names already in use, so that generated aliases don't collide with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = set(build_scope(expression).selected_sources)

        for select in expression.selects:
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                # Two aliases: the first names the position, the second the value.
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                # NOTE(review): a POSEXPLODE with a single plain alias skips this block, leaving
                # pos_alias == "" in column_names below -- confirm that input can't occur.
                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    column_names = [explode_alias, pos_alias]
                    # NOTE(review): pop() removes an element from the list being iterated,
                    # which may skip the following projection -- confirm.
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
Convert explode/posexplode into unnest (used in hive -> presto).
def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """
    Strip the MERGE target's table name (or alias) from columns in the WHEN clauses.

    Columns qualified by the target are replaced with unqualified columns of the same name;
    the WHEN clauses are rewritten in place.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    target = expression.this
    target_alias = target.args.get("alias")
    target_names = {target.this}
    if target_alias:
        target_names.add(target_alias.this)

    def _unqualify(node):
        if isinstance(node, exp.Column) and node.args.get("table") in target_names:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        when.transform(_unqualify, copy=False)

    return expression
Remove references to the MERGE target table (or its alias) from columns in the WHEN clauses.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite PERCENTILE_CONT/PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) as an
    ApproxQuantile(x, q) expression; any other node is returned unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, (exp.PercentileCont, exp.PercentileDisc)):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give each CTE of a WITH RECURSIVE clause that lacks a column list one derived from its
    query's projections (from the left-hand side of a UNION).

    Unnamed projections are named `_c_0`, `_c_1`, ..., numbered across all CTEs of the clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    counter = itertools.count()

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        names = []
        for projection in query.selects:
            name = projection.alias_or_name or f"_c_{next(counter)}"
            names.append(exp.to_identifier(name))
        table_alias.set("columns", names)

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (used only when the generator has no
    "_sql" method for the resulting expression).

    Args:
        transforms: sequence of transform functions. These will be called in order, on a copy
            of the input expression. An empty sequence just generates SQL for the copy.

    Returns:
        Function that can be used as a generator transform. It raises ValueError when the
        result can't be converted: either no handler exists for its type, or the only handler
        is a `TRANSFORMS` entry for the input's own type (which would recurse forever).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Transforms may mutate the tree in place, so they run on a copy of the caller's tree.
        # (The loop variable is not named `t`, which would shadow the `typing` alias.)
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            # Ensures we don't enter an infinite loop. This can happen when the original expression
            # has the same type as the final expression and there's no _sql method available for it,
            # because then it'd re-enter _to_sql.
            if expression_type is type(expression):
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS
function (used only when the generator has no "_sql" method for the resulting expression).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.