sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # One shared sequence so every joined subquery gets a distinct "_u_<n>" alias.
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        if not parent:
            # Nothing to join onto when there is no enclosing SELECT.
            continue
        if scope.external_columns:
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Rewrite an uncorrelated, single-projection subquery as a join on its parent SELECT.

    The AST is modified in place. The rewrite depends on the closest ``exp.Condition``
    ancestor of ``select``:

    - Anything other than ``IN`` / ``ANY``: the subquery is treated as a scalar. Its parent
      node is replaced by a column reference (wrapped in ``MAX`` in aggregating contexts)
      and the subquery is CROSS JOINed.
    - ``IN`` / ``= ANY``: the predicate is replaced by a NOT NULL check on the join key and
      the subquery is LEFT JOINed on ``<other operand> = <join key>``.

    The subquery is left alone when it has more than one projection, when no suitable
    predicate exists in ``parent_select``, when ``parent_select`` has no FROM clause, or
    (for ``IN`` / ``ANY``) when the subquery contains LIMIT or OFFSET.

    Args:
        select: the subquery's SELECT expression.
        parent_select: the SELECT that encloses ``select``.
        next_alias_name: zero-argument callable returning a fresh alias name.

    Returns:
        None.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE: the alias is drawn before the guard below, so a name is consumed even when the
    # subquery ends up being skipped. Kept as-is so generated alias numbering is unchanged.
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # In a HAVING of the parent, or outside any clause of the parent while the parent
        # groups/aggregates, the joined column is wrapped in MAX.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For ANY, the predicate that gets rewritten is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # Inside a JOIN condition the predicate becomes TRUE and the NOT NULL check is
        # added to the parent's WHERE instead.
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    group = select.args.get("group")

    if group:
        if {value.this} != set(group.expressions):
            # Grouped on something other than the projected value: wrap the subquery and
            # group the wrapper on the value so the join key is distinct.
            select = (
                exp.select(exp.column(value.alias, "_q"))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    elif not value.find(exp.AggFunc):
        # Deduplicate the join key. Skipped when the projection is an aggregate: grouping
        # by an aggregate (e.g. GROUP BY MAX(y.a)) is invalid SQL.
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped LEFT JOIN on its parent SELECT.

    The AST is modified in place. Every external column must sit in a binary predicate
    whose nearest WHERE ancestor is the subquery's own WHERE clause. Those predicates are
    replaced by TRUE inside the subquery, and the equality ones become the join condition.

    Nothing is rewritten when the subquery has no WHERE clause, its WHERE contains OR,
    the subquery contains LIMIT or OFFSET, an external column is used outside that WHERE
    or in a non-binary predicate, or none of the correlation predicates is an equality.

    Args:
        select: the correlated subquery's SELECT expression.
        parent_select: the SELECT that encloses ``select``.
        external_columns: the external columns of the subquery's scope.
        next_alias_name: zero-argument callable returning a fresh alias name.

    Returns:
        None.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # Triples of (subquery-side operand, external column, predicate containing both).
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is whichever side of the predicate does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery itself is one of the parent's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped are aggregated so they can still be projected.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # Reference to the subquery's value through the join alias.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL check on the first join key.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            # NOTE(review): for a subquery projection `alias` is already an Alias node here,
            # so COALESCE wraps the alias rather than the reverse -- confirm this is intended.
            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # Neutralise the correlation predicate inside the subquery. The detached
        # predicate object is reused below as (part of) the join condition.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # Non-equality on a non-grouped key: test it per element of the collected array.
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """Swap ``expression`` in its tree for ``exp.condition(condition)``; return what ``replace`` returns."""
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand of a predicate that sits opposite the subquery.

    - ``IN``: the tested expression (``expression.this``).
    - ``ANY`` / ``ALL``: resolved through the parent node.
    - Other binary expressions: the right side when the left side is a
      Subquery/Any/Exists/All node, otherwise the left side.
    - Anything else: ``None``.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def
unnest_subqueries(expression):
def unnest_subqueries(expression):
    """
    Turn predicates that contain subqueries into joins on the enclosing SELECT.

    Every scope produced by ``traverse_scope`` that has a parent SELECT is visited:
    scopes with external columns are handed to ``decorrelate``, and the remaining
    scopes of type ``ScopeType.SUBQUERY`` are handed to ``unnest``. Joined subqueries
    receive aliases from a single ``_u_`` name sequence.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: the expression that was passed in, after rewriting
    """
    fresh_alias = name_sequence("_u_")

    for scope in traverse_scope(expression):
        inner = scope.expression
        outer = inner.parent_select

        if outer:
            if scope.external_columns:
                # Correlated: the subquery references columns of an outer scope.
                decorrelate(inner, outer, scope.external_columns, fresh_alias)
            elif scope.scope_type == ScopeType.SUBQUERY:
                unnest(inner, outer, fresh_alias)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Scalar subqueries are converted into cross joins. Correlated or vectorized subqueries are converted into a grouped subquery that is left-joined, so that the join does not become a many-to-many join.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def
unnest(select, parent_select, next_alias_name):
def unnest(select, parent_select, next_alias_name):
    """
    Rewrite an uncorrelated, single-projection subquery as a join on its parent SELECT.

    The AST is modified in place. The rewrite depends on the closest ``exp.Condition``
    ancestor of ``select``:

    - Anything other than ``IN`` / ``ANY``: the subquery is treated as a scalar. Its parent
      node is replaced by a column reference (wrapped in ``MAX`` in aggregating contexts)
      and the subquery is CROSS JOINed.
    - ``IN`` / ``= ANY``: the predicate is replaced by a NOT NULL check on the join key and
      the subquery is LEFT JOINed on ``<other operand> = <join key>``.

    The subquery is left alone when it has more than one projection, when no suitable
    predicate exists in ``parent_select``, when ``parent_select`` has no FROM clause, or
    (for ``IN`` / ``ANY``) when the subquery contains LIMIT or OFFSET.

    Args:
        select: the subquery's SELECT expression.
        parent_select: the SELECT that encloses ``select``.
        next_alias_name: zero-argument callable returning a fresh alias name.

    Returns:
        None.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    # NOTE: the alias is drawn before the guard below, so a name is consumed even when the
    # subquery ends up being skipped. Kept as-is so generated alias numbering is unchanged.
    alias = next_alias_name()

    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # In a HAVING of the parent, or outside any clause of the parent while the parent
        # groups/aggregates, the joined column is wrapped in MAX.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For ANY, the predicate that gets rewritten is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # Inside a JOIN condition the predicate becomes TRUE and the NOT NULL check is
        # added to the parent's WHERE instead.
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    group = select.args.get("group")

    if group:
        if {value.this} != set(group.expressions):
            # Grouped on something other than the projected value: wrap the subquery and
            # group the wrapper on the value so the join key is distinct.
            select = (
                exp.select(exp.column(value.alias, "_q"))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    elif not value.find(exp.AggFunc):
        # Deduplicate the join key. Skipped when the projection is an aggregate: grouping
        # by an aggregate (e.g. GROUP BY MAX(y.a)) is invalid SQL.
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def
decorrelate(select, parent_select, external_columns, next_alias_name):
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Rewrite a correlated subquery as a grouped LEFT JOIN on its parent SELECT.

    The AST is modified in place. Every external column must sit in a binary predicate
    whose nearest WHERE ancestor is the subquery's own WHERE clause. Those predicates are
    replaced by TRUE inside the subquery, and the equality ones become the join condition.

    Nothing is rewritten when the subquery has no WHERE clause, its WHERE contains OR,
    the subquery contains LIMIT or OFFSET, an external column is used outside that WHERE
    or in a non-binary predicate, or none of the correlation predicates is an equality.

    Args:
        select: the correlated subquery's SELECT expression.
        parent_select: the SELECT that encloses ``select``.
        external_columns: the external columns of the subquery's scope.
        next_alias_name: zero-argument callable returning a fresh alias name.

    Returns:
        None.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = next_alias_name()
    # Triples of (subquery-side operand, external column, predicate containing both).
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # The key is whichever side of the predicate does not contain the external column.
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery itself is one of the parent's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped are aggregated so they can still be projected.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # Reference to the subquery's value through the join alias.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    if isinstance(parent_predicate, exp.Exists):
        # EXISTS becomes a NOT NULL check on the first join key.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            # NOTE(review): for a subquery projection `alias` is already an Alias node here,
            # so COALESCE wraps the alias rather than the reverse -- confirm this is intended.
            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # Neutralise the correlation predicate inside the subquery. The detached
        # predicate object is reused below as (part of) the join condition.
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # Non-equality on a non-grouped key: test it per element of the collected array.
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )