sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Map each aliased projection to its 1-based position in the SELECT list.
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            # Only unqualified columns can be references to a select alias.
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Detach the DISTINCT node; its ON columns become the window partition.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        # Captured before the row-number projection is appended below.
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        # The query's ORDER BY (when present) is moved into the window; otherwise
        # the window is ordered by the DISTINCT ON columns themselves.
        if order:
            window.set("order", order.pop().copy())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # Wrap the original query and keep only the first row of each partition.
        return (
            exp.select(*outer_selects)
            .from_(expression.subquery("_t"))
            .where(exp.column(row_number).eq(1))
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer projection list is built from the names of the original
        # projections, before any helper projection is appended below.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        qualify_filters = expression.args["qualify"].pop().this

        # For a star projection only window expressions are hoisted.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                # A window that is the direct child of QUALIFY is the whole filter.
                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode (used in presto -> hive)."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                # WITH ORDINALITY maps to POSEXPLODE, a plain UNNEST to EXPLODE.
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

                # NOTE(review): this removes from the list that is being iterated.
                expression.args["joins"].remove(join)

                # One LATERAL VIEW per (unnested expression, alias column) pair.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """Convert explode/posexplode into unnest (used in hive -> presto)."""
    if isinstance(expression, exp.Select):
        from sqlglot.optimizer.scope import Scope

        # Names already in use, so generated aliases do not collide with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            # Unwrap `... AS x` / `... AS (pos, x)` and remember the user's aliases.
            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    column_names = [explode_alias, pos_alias]
                    # The single projection is dropped and two are appended in its place.
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                # Without a FROM clause the UNNEST becomes the source; otherwise it is cross joined.
                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression


# Expression types handled by the WITHIN GROUP transforms below.
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile expression in a WITHIN GROUP node.

    Applies to PERCENTILES nodes that have an `expression` argument and are not already the child
    of a WithinGroup: the node's `this` is moved into an ORDER BY, its `expression` argument becomes
    its new `this`, and the node is wrapped in a new WithinGroup.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new WithinGroup node, or the unchanged expression if it does not apply.
    """
    if (
        isinstance(expression, PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile> WITHIN GROUP (ORDER BY ...)` with an ApproxQuantile node.

    The quantile is the percentile node's `this`; the input value is the `this` of the first
    Ordered node found under the WithinGroup.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The ApproxQuantile replacement, or the unchanged expression if it does not apply.
    """
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Add explicit column names to the CTEs of a recursive WITH clause that lack them.

    Names are taken from the projections of the CTE's query (the left-hand side for a Union);
    projections without a name get one generated from the "_c_" name sequence.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.Union):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of an 'epoch' cast to a temporal type with a timestamp string.

    Applies to Cast/TryCast nodes whose name is "epoch" (case-insensitive) and whose target type is
    in `exp.DataType.TEMPORAL_TYPES`; the operand becomes the literal '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # The first transform receives a copy, so the caller's tree is left untouched.
        expression = transforms[0](expression.copy())
        for t in transforms[1:]:
            expression = t(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a Timestamp node that has no `expression` argument as a cast to TIMESTAMP.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A cast of the node's `this` to TIMESTAMP, or the unchanged expression if it does not apply.
    """
    if isinstance(expression, exp.Timestamp) and not expression.expression:
        return exp.cast(
            expression.this,
            to=exp.DataType.Type.TIMESTAMP,
        )
    return expression
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to select aliases for the ordinal of the aliased projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # 1-based position of every aliased projection, keyed by its alias.
    positions = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            positions[projection.alias] = position

    for term in expression.expressions:
        # Only bare (table-less) columns can be references to a select alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        position = positions.get(term.name)
        if position is not None:
            term.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The original query becomes a subquery named "_t" with an extra ROW_NUMBER() projection that is
    partitioned by the DISTINCT ON columns; the outer query keeps the rows numbered 1. The outer
    query refers to the subquery's projections by their output names, so projections without a
    usable name are given a generated "_c" alias first. Re-using the inner expressions themselves
    would reference tables and columns that are not visible outside the subquery.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct or not isinstance(distinct.args.get("on"), exp.Tuple):
        return expression

    # Detach the DISTINCT node; its ON columns become the window partition.
    distinct_cols = distinct.pop().args["on"].expressions

    # Build the outer projection list as references to the inner projections' names.
    taken = set(expression.named_selects)
    outer_selects = []
    for select in list(expression.selects):
        if select.is_star:
            # A (possibly qualified) star is re-exposed as a plain star of the subquery.
            outer_selects.append(exp.Star())
            continue

        if isinstance(select, exp.Alias):
            name = select.args["alias"]
        elif isinstance(select, exp.Column):
            name = select.this
        else:
            alias = select.alias
            if not alias:
                # No output name to refer to: alias the projection in the subquery.
                alias = find_new_name(taken, "_c")
                taken.add(alias)
                select.replace(exp.alias_(select.copy(), alias))
            name = exp.to_identifier(alias)

        # Copy the identifier so quoting is preserved and the inner node keeps its own child.
        outer_selects.append(exp.Column(this=name.copy()))

    row_number = find_new_name(expression.named_selects, "_row_number")
    window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
    order = expression.args.get("order")

    # The query's ORDER BY (when present) decides which row of each partition is first;
    # otherwise the window is ordered by the DISTINCT ON columns themselves.
    if order:
        window.set("order", order.pop().copy())
    else:
        window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

    window = exp.alias_(window, row_number)
    expression.select(window, copy=False)

    return (
        exp.select(*outer_selects)
        .from_(expression.subquery("_t"))
        .where(exp.column(row_number).eq(1))
    )
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (as subquery "_t") filtered by the former
        QUALIFY condition, or the unchanged expression when it is not a SELECT with QUALIFY.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer projection list is built from the names of the original projections,
        # before any helper projection is appended to the subquery below.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        # Detach QUALIFY from the query; its condition becomes the outer WHERE.
        qualify_filters = expression.args["qualify"].pop().this

        # For a star projection only window expressions are hoisted into the subquery.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        # NOTE(review): the condition tree is edited while find_all is walking it, so the
        # statement order in this loop matters.
        for expr in qualify_filters.find_all(select_candidates):
            if isinstance(expr, exp.Window):
                # Project the window in the subquery and filter on its alias instead.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(expr, alias), copy=False)
                column = exp.column(alias)

                # A window that is the direct child of QUALIFY is the whole filter.
                if isinstance(expr.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    expr.replace(column)
            elif expr.name not in expression.named_selects:
                # A filtered column that is not selected must be projected by the subquery.
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip the parameters (e.g. precision) from every parameterized data type in the expression.

    Some dialects only accept such parameters in DDL, not in other expressions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with DataTypeParam children removed from all DataType nodes.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for param in data_type.expressions:
            if isinstance(param, exp.DataTypeParam):
                continue
            kept.append(param)

        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode (used in presto -> hive).

    Every join of a SELECT whose target is an UNNEST is removed and replaced by one LATERAL VIEW
    per (unnested expression, alias column) pair. UNNEST ... WITH ORDINALITY maps to POSEXPLODE,
    a plain UNNEST to EXPLODE.

    NOTE: an UNNEST without an alias yields no (expression, column) pairs, so its join is removed
    without adding a lateral view.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    joins = expression.args.get("joins") or []

    # Iterate over a snapshot: entries are removed from the live list inside the loop, and
    # mutating a list while iterating it would skip the join that follows each removed one.
    for join in list(joins):
        unnest = join.this
        if not isinstance(unnest, exp.Unnest):
            continue

        alias = unnest.args.get("alias")
        udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode

        joins.remove(join)

        for e, column in zip(unnest.expressions, alias.columns if alias else []):
            expression.append(
                "laterals",
                exp.Lateral(
                    this=udtf(this=e),
                    view=True,
                    alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                ),
            )

    return expression
Convert cross join unnest into lateral view explode (used in presto -> hive).
def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Convert explode/posexplode into unnest (used in hive -> presto).

    Each EXPLODE/POSEXPLODE projection of a SELECT is replaced by column references to an aliased
    UNNEST source, which becomes the FROM clause or is CROSS JOINed to the existing one.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        from sqlglot.optimizer.scope import Scope

        # Names already in use, so generated aliases do not collide with them.
        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            # Keep a handle on the projection itself; `select` may be unwrapped below.
            to_replace = select

            pos_alias = ""
            explode_alias = ""

            # Unwrap `... AS x` / `... AS (pos, x)` and remember the user's aliases.
            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)

                # Generate column names only when the user supplied none.
                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)

                if is_posexplode:
                    column_names = [explode_alias, pos_alias]
                    # The single projection is dropped and two are appended in its place.
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

                # Without a FROM clause the UNNEST becomes the source; otherwise it is cross joined.
                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)

    return expression
Convert explode/posexplode into unnest (used in hive -> presto).
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile expression in a WITHIN GROUP node.

    Applies to PERCENTILES nodes that carry an `expression` argument and are not already the
    child of a WithinGroup. The node's `this` moves into an ORDER BY, its `expression` argument
    becomes its new `this`, and the node is wrapped in a new WithinGroup.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The new WithinGroup node, or the unchanged expression if the rewrite does not apply.
    """
    if not isinstance(expression, PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_operand = expression.this.pop()
    promoted = expression.expression.pop()
    expression.set("this", promoted)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_operand)])
    return exp.WithinGroup(this=expression, expression=ordering)
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile> WITHIN GROUP (ORDER BY ...)` with an ApproxQuantile node.

    The quantile is the percentile node's `this`; the input value is the `this` of the first
    Ordered node found under the WithinGroup.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The ApproxQuantile replacement, or the unchanged expression if the rewrite does not apply.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(approx_quantile)
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Add explicit column names to the CTEs of a recursive WITH clause that lack them.

    Names come from the projections of the CTE's query (the left-hand side when the query is a
    Union); a projection without a name gets the next name of the "_c_" sequence.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.Union):
            query = query.this

        column_names = []
        for projection in query.selects:
            # A generated name is consumed only for projections that have no name.
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        table_alias.set("columns", column_names)

    return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of an 'epoch' cast to a temporal type with a timestamp string.

    Applies to Cast/TryCast nodes whose name is "epoch" (case-insensitive) and whose target type
    is in `exp.DataType.TEMPORAL_TYPES`; the operand becomes the literal '1970-01-01 00:00:00'.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed; the expression is then generated as-is.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Work on a copy so the transforms never mutate the caller's tree.
        expression = expression.copy()
        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method for the resulting node.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS
function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.