Edit on GitHub

sqlglot.planner

  1from __future__ import annotations
  2
  3import math
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.helper import name_sequence
  8from sqlglot.optimizer.eliminate_joins import join_condition
  9
 10
 11class Plan:
 12    def __init__(self, expression: exp.Expression) -> None:
 13        self.expression = expression.copy()
 14        self.root = Step.from_expression(self.expression)
 15        self._dag: t.Dict[Step, t.Set[Step]] = {}
 16
 17    @property
 18    def dag(self) -> t.Dict[Step, t.Set[Step]]:
 19        if not self._dag:
 20            dag: t.Dict[Step, t.Set[Step]] = {}
 21            nodes = {self.root}
 22
 23            while nodes:
 24                node = nodes.pop()
 25                dag[node] = set()
 26
 27                for dep in node.dependencies:
 28                    dag[node].add(dep)
 29                    nodes.add(dep)
 30
 31            self._dag = dag
 32
 33        return self._dag
 34
 35    @property
 36    def leaves(self) -> t.Iterator[Step]:
 37        return (node for node, deps in self.dag.items() if not deps)
 38
 39    def __repr__(self) -> str:
 40        return f"Plan\n----\n{repr(self.root)}"
 41
 42
 43class Step:
 44    @classmethod
 45    def from_expression(
 46        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
 47    ) -> Step:
 48        """
 49        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 50        Note: the expression's tables and subqueries must be aliased for this method to work. For
 51        example, given the following expression:
 52
 53        SELECT
 54          x.a,
 55          SUM(x.b)
 56        FROM x AS x
 57        JOIN y AS y
 58          ON x.a = y.a
 59        GROUP BY x.a
 60
 61        the following DAG is produced (the expression IDs might differ per execution):
 62
 63        - Aggregate: x (4347984624)
 64            Context:
 65              Aggregations:
 66                - SUM(x.b)
 67              Group:
 68                - x.a
 69            Projections:
 70              - x.a
 71              - "x".""
 72            Dependencies:
 73            - Join: x (4347985296)
 74              Context:
 75                y:
 76                On: x.a = y.a
 77              Projections:
 78              Dependencies:
 79              - Scan: x (4347983136)
 80                Context:
 81                  Source: x AS x
 82                Projections:
 83              - Scan: y (4343416624)
 84                Context:
 85                  Source: y AS y
 86                Projections:
 87
 88        Args:
 89            expression: the expression to build the DAG from.
 90            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 91
 92        Returns:
 93            A Step DAG corresponding to `expression`.
 94        """
 95        ctes = ctes or {}
 96        expression = expression.unnest()
 97        with_ = expression.args.get("with")
 98
 99        # CTEs break the mold of scope and introduce themselves to all in the context.
100        if with_:
101            ctes = ctes.copy()
102            for cte in with_.expressions:
103                step = Step.from_expression(cte.this, ctes)
104                step.name = cte.alias
105                ctes[step.name] = step  # type: ignore
106
107        from_ = expression.args.get("from")
108
109        if isinstance(expression, exp.Select) and from_:
110            step = Scan.from_expression(from_.this, ctes)
111        elif isinstance(expression, exp.Union):
112            step = SetOperation.from_expression(expression, ctes)
113        else:
114            step = Scan()
115
116        joins = expression.args.get("joins")
117
118        if joins:
119            join = Join.from_joins(joins, ctes)
120            join.name = step.name
121            join.add_dependency(step)
122            step = join
123
124        projections = []  # final selects in this chain of steps representing a select
125        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
126        aggregations = set()
127        next_operand_name = name_sequence("_a_")
128
129        def extract_agg_operands(expression):
130            agg_funcs = tuple(expression.find_all(exp.AggFunc))
131            if agg_funcs:
132                aggregations.add(expression)
133
134            for agg in agg_funcs:
135                for operand in agg.unnest_operands():
136                    if isinstance(operand, exp.Column):
137                        continue
138                    if operand not in operands:
139                        operands[operand] = next_operand_name()
140
141                    operand.replace(exp.column(operands[operand], quoted=True))
142
143            return bool(agg_funcs)
144
145        for e in expression.expressions:
146            if e.find(exp.AggFunc):
147                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
148                extract_agg_operands(e)
149            else:
150                projections.append(e)
151
152        where = expression.args.get("where")
153
154        if where:
155            step.condition = where.this
156
157        group = expression.args.get("group")
158
159        if group or aggregations:
160            aggregate = Aggregate()
161            aggregate.source = step.name
162            aggregate.name = step.name
163
164            having = expression.args.get("having")
165
166            if having:
167                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
168                    aggregate.condition = exp.column("_h", step.name, quoted=True)
169                else:
170                    aggregate.condition = having.this
171
172            aggregate.operands = tuple(
173                alias(operand, alias_) for operand, alias_ in operands.items()
174            )
175            aggregate.aggregations = list(aggregations)
176
177            # give aggregates names and replace projections with references to them
178            aggregate.group = {
179                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
180            }
181
182            intermediate: t.Dict[str | exp.Expression, str] = {}
183            for k, v in aggregate.group.items():
184                intermediate[v] = k
185                if isinstance(v, exp.Column):
186                    intermediate[v.name] = k
187
188            for projection in projections:
189                for node, *_ in projection.walk():
190                    name = intermediate.get(node)
191                    if name:
192                        node.replace(exp.column(name, step.name))
193
194            if aggregate.condition:
195                for node, *_ in aggregate.condition.walk():
196                    name = intermediate.get(node) or intermediate.get(node.name)
197                    if name:
198                        node.replace(exp.column(name, step.name))
199
200            aggregate.add_dependency(step)
201            step = aggregate
202
203        order = expression.args.get("order")
204
205        if order:
206            if isinstance(step, Aggregate):
207                for ordered in order.expressions:
208                    if ordered.find(exp.AggFunc):
209                        operand_name = next_operand_name()
210                        extract_agg_operands(exp.alias_(ordered.this, operand_name, quoted=True))
211                        ordered.this.replace(exp.column(operand_name, quoted=True))
212
213                step.aggregations = list(aggregations)
214
215            sort = Sort()
216            sort.name = step.name
217            sort.key = order.expressions
218            sort.add_dependency(step)
219            step = sort
220
221        step.projections = projections
222
223        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
224            distinct = Aggregate()
225            distinct.source = step.name
226            distinct.name = step.name
227            distinct.group = {
228                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
229                for e in projections or expression.expressions
230            }
231            distinct.add_dependency(step)
232            step = distinct
233
234        limit = expression.args.get("limit")
235
236        if limit:
237            step.limit = int(limit.text("expression"))
238
239        return step
240
241    def __init__(self) -> None:
242        self.name: t.Optional[str] = None
243        self.dependencies: t.Set[Step] = set()
244        self.dependents: t.Set[Step] = set()
245        self.projections: t.Sequence[exp.Expression] = []
246        self.limit: float = math.inf
247        self.condition: t.Optional[exp.Expression] = None
248
249    def add_dependency(self, dependency: Step) -> None:
250        self.dependencies.add(dependency)
251        dependency.dependents.add(self)
252
253    def __repr__(self) -> str:
254        return self.to_s()
255
256    def to_s(self, level: int = 0) -> str:
257        indent = "  " * level
258        nested = f"{indent}    "
259
260        context = self._to_s(f"{nested}  ")
261
262        if context:
263            context = [f"{nested}Context:"] + context
264
265        lines = [
266            f"{indent}- {self.id}",
267            *context,
268            f"{nested}Projections:",
269        ]
270
271        for expression in self.projections:
272            lines.append(f"{nested}  - {expression.sql()}")
273
274        if self.condition:
275            lines.append(f"{nested}Condition: {self.condition.sql()}")
276
277        if self.limit is not math.inf:
278            lines.append(f"{nested}Limit: {self.limit}")
279
280        if self.dependencies:
281            lines.append(f"{nested}Dependencies:")
282            for dependency in self.dependencies:
283                lines.append("  " + dependency.to_s(level + 1))
284
285        return "\n".join(lines)
286
287    @property
288    def type_name(self) -> str:
289        return self.__class__.__name__
290
291    @property
292    def id(self) -> str:
293        name = self.name
294        name = f" {name}" if name else ""
295        return f"{self.type_name}:{name} ({id(self)})"
296
297    def _to_s(self, _indent: str) -> t.List[str]:
298        return []
299
300
class Scan(Step):
    """A step that reads a FROM/JOIN source; without a source it renders as `-static-`."""

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Step:
        """Build the step that reads `expression`.

        A subquery is planned recursively and the resulting step is renamed to the
        subquery's alias. Any other source becomes a `Scan` named after its alias (or
        name); if its name matches a known CTE, that CTE's step becomes a dependency.
        """
        name = expression.alias_or_name

        if isinstance(expression, exp.Subquery):
            inner = Step.from_expression(expression.this, ctes)
            inner.name = name
            return inner

        scan = Scan()
        scan.name = name
        scan.source = expression
        cte_step = ctes.get(expression.name) if ctes else None
        if cte_step is not None:
            scan.add_dependency(cte_step)

        return scan

    def __init__(self) -> None:
        super().__init__()
        # The source expression being read, or None.
        self.source: t.Optional[exp.Expression] = None

    def _to_s(self, indent: str) -> t.List[str]:
        source_sql = self.source.sql() if self.source else "-static-"  # type: ignore
        return [f"{indent}Source: {source_sql}"]
329
330
class Join(Step):
    """A step joining its dependencies.

    `joins` maps each joined source's alias (or name) to a dict with the keys
    "side", "join_key", "source_key" and "condition".
    """

    @classmethod
    def from_joins(
        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Step:
        """Build one Join step covering every JOIN clause in `joins`.

        Each joined source is planned via `Scan.from_expression` and added as a
        dependency; its keys and condition come from `join_condition`.
        """
        joined = Join()

        for join in joins:
            source_key, join_key, condition = join_condition(join)
            details = {
                "side": join.side,  # type: ignore
                "join_key": join_key,
                "source_key": source_key,
                "condition": condition,
            }
            joined.joins[join.alias_or_name] = details
            joined.add_dependency(Scan.from_expression(join.this, ctes))

        return joined

    def __init__(self) -> None:
        super().__init__()
        self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}

    def _to_s(self, indent: str) -> t.List[str]:
        rendered: t.List[str] = []
        for name, details in self.joins.items():
            rendered.append(f"{indent}{name}: {details['side']}")
            condition = details.get("condition")
            if condition:
                rendered.append(f"{indent}On: {condition.sql()}")  # type: ignore
        return rendered
362
363
class Aggregate(Step):
    """A grouping/aggregation step; the planner also uses it to implement DISTINCT."""

    def __init__(self) -> None:
        super().__init__()
        # Expressions containing aggregate functions, e.g. SUM(x.b).
        self.aggregations: t.List[exp.Expression] = []
        # Aliased operand expressions referenced by the aggregations.
        self.operands: t.Tuple[exp.Expression, ...] = ()
        # Group key name -> grouping expression.
        self.group: t.Dict[str, exp.Expression] = {}
        # Name of the step this aggregate reads from.
        self.source: t.Optional[str] = None

    def _to_s(self, indent: str) -> t.List[str]:
        bullet = f"{indent}  - "
        rendered = [f"{indent}Aggregations:"]
        rendered.extend(bullet + agg.sql() for agg in self.aggregations)

        if self.group:
            rendered.append(f"{indent}Group:")
            rendered.extend(bullet + key.sql() for key in self.group.values())
        if self.condition:
            rendered += [f"{indent}Having:", bullet + self.condition.sql()]
        if self.operands:
            rendered.append(f"{indent}Operands:")
            rendered.extend(bullet + operand.sql() for operand in self.operands)

        return rendered
391
392
class Sort(Step):
    """A step ordering its input by `key` (the ORDER BY expressions)."""

    def __init__(self) -> None:
        super().__init__()
        # ORDER BY expressions; None until the planner assigns them.
        self.key: t.Optional[t.Sequence[exp.Expression]] = None

    def _to_s(self, indent: str) -> t.List[str]:
        lines = [f"{indent}Key:"]

        # Tolerate an unassigned key so repr() of a fresh Sort does not raise.
        for expression in self.key or ():
            lines.append(f"{indent}  - {expression.sql()}")

        return lines
405
406
class SetOperation(Step):
    """A step combining two sub-plans with a set operator.

    `op` is the expression class of the operator, and `left`/`right` are the names
    of the two input steps. Both inputs are also registered as dependencies.
    """

    def __init__(
        self,
        op: t.Type[exp.Expression],
        left: str | None,
        right: str | None,
        distinct: bool = False,
    ) -> None:
        super().__init__()
        self.op = op
        self.left = left
        self.right = right
        self.distinct = distinct

    @classmethod
    def from_expression(
        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
    ) -> Step:
        """Plan both operands of the set operation `expression` and combine them."""
        assert isinstance(expression, exp.Union)
        lhs = Step.from_expression(expression.left, ctes)
        rhs = Step.from_expression(expression.right, ctes)
        combined = cls(
            op=expression.__class__,
            left=lhs.name,
            right=rhs.name,
            distinct=bool(expression.args.get("distinct")),
        )
        for operand_step in (lhs, rhs):
            combined.add_dependency(operand_step)
        return combined

    def _to_s(self, indent: str) -> t.List[str]:
        if not self.distinct:
            return []
        return [f"{indent}Distinct: {self.distinct}"]

    @property
    def type_name(self) -> str:
        # Show the concrete operator class name rather than "SetOperation".
        return self.op.__name__
class Plan:
class Plan:
    """An execution plan: the DAG of `Step`s built from a SQL expression.

    Args:
        expression: the expression to plan. It is copied first, so the in-place
            rewrites performed by `Step.from_expression` (its `replace` calls) do not
            modify the caller's tree.
    """

    def __init__(self, expression: exp.Expression) -> None:
        self.expression = expression.copy()
        self.root = Step.from_expression(self.expression)
        # Populated lazily by the `dag` property.
        self._dag: t.Dict[Step, t.Set[Step]] = {}

    @property
    def dag(self) -> t.Dict[Step, t.Set[Step]]:
        """Map every step reachable from `root` to the set of steps it depends on.

        The mapping is computed on first access and cached afterwards.
        """
        if not self._dag:
            dag: t.Dict[Step, t.Set[Step]] = {}
            pending = [self.root]

            while pending:
                node = pending.pop()
                # A step shared by several parents (e.g. a CTE referenced more than
                # once) is reachable along several paths; expand it only once.
                if node in dag:
                    continue
                dag[node] = set(node.dependencies)
                pending.extend(node.dependencies)

            self._dag = dag

        return self._dag

    @property
    def leaves(self) -> t.Iterator[Step]:
        """Yield the steps that have no dependencies."""
        return (node for node, deps in self.dag.items() if not deps)

    def __repr__(self) -> str:
        return f"Plan\n----\n{repr(self.root)}"
Plan(expression: sqlglot.expressions.Expression)
13    def __init__(self, expression: exp.Expression) -> None:
14        self.expression = expression.copy()
15        self.root = Step.from_expression(self.expression)
16        self._dag: t.Dict[Step, t.Set[Step]] = {}
expression
root
leaves: Iterator[sqlglot.planner.Step]
class Step:
 44class Step:
 45    @classmethod
 46    def from_expression(
 47        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
 48    ) -> Step:
 49        """
 50        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 51        Note: the expression's tables and subqueries must be aliased for this method to work. For
 52        example, given the following expression:
 53
 54        SELECT
 55          x.a,
 56          SUM(x.b)
 57        FROM x AS x
 58        JOIN y AS y
 59          ON x.a = y.a
 60        GROUP BY x.a
 61
 62        the following DAG is produced (the expression IDs might differ per execution):
 63
 64        - Aggregate: x (4347984624)
 65            Context:
 66              Aggregations:
 67                - SUM(x.b)
 68              Group:
 69                - x.a
 70            Projections:
 71              - x.a
 72              - "x".""
 73            Dependencies:
 74            - Join: x (4347985296)
 75              Context:
 76                y:
 77                On: x.a = y.a
 78              Projections:
 79              Dependencies:
 80              - Scan: x (4347983136)
 81                Context:
 82                  Source: x AS x
 83                Projections:
 84              - Scan: y (4343416624)
 85                Context:
 86                  Source: y AS y
 87                Projections:
 88
 89        Args:
 90            expression: the expression to build the DAG from.
 91            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 92
 93        Returns:
 94            A Step DAG corresponding to `expression`.
 95        """
 96        ctes = ctes or {}
 97        expression = expression.unnest()
 98        with_ = expression.args.get("with")
 99
100        # CTEs break the mold of scope and introduce themselves to all in the context.
101        if with_:
102            ctes = ctes.copy()
103            for cte in with_.expressions:
104                step = Step.from_expression(cte.this, ctes)
105                step.name = cte.alias
106                ctes[step.name] = step  # type: ignore
107
108        from_ = expression.args.get("from")
109
110        if isinstance(expression, exp.Select) and from_:
111            step = Scan.from_expression(from_.this, ctes)
112        elif isinstance(expression, exp.Union):
113            step = SetOperation.from_expression(expression, ctes)
114        else:
115            step = Scan()
116
117        joins = expression.args.get("joins")
118
119        if joins:
120            join = Join.from_joins(joins, ctes)
121            join.name = step.name
122            join.add_dependency(step)
123            step = join
124
125        projections = []  # final selects in this chain of steps representing a select
126        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
127        aggregations = set()
128        next_operand_name = name_sequence("_a_")
129
130        def extract_agg_operands(expression):
131            agg_funcs = tuple(expression.find_all(exp.AggFunc))
132            if agg_funcs:
133                aggregations.add(expression)
134
135            for agg in agg_funcs:
136                for operand in agg.unnest_operands():
137                    if isinstance(operand, exp.Column):
138                        continue
139                    if operand not in operands:
140                        operands[operand] = next_operand_name()
141
142                    operand.replace(exp.column(operands[operand], quoted=True))
143
144            return bool(agg_funcs)
145
146        for e in expression.expressions:
147            if e.find(exp.AggFunc):
148                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
149                extract_agg_operands(e)
150            else:
151                projections.append(e)
152
153        where = expression.args.get("where")
154
155        if where:
156            step.condition = where.this
157
158        group = expression.args.get("group")
159
160        if group or aggregations:
161            aggregate = Aggregate()
162            aggregate.source = step.name
163            aggregate.name = step.name
164
165            having = expression.args.get("having")
166
167            if having:
168                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
169                    aggregate.condition = exp.column("_h", step.name, quoted=True)
170                else:
171                    aggregate.condition = having.this
172
173            aggregate.operands = tuple(
174                alias(operand, alias_) for operand, alias_ in operands.items()
175            )
176            aggregate.aggregations = list(aggregations)
177
178            # give aggregates names and replace projections with references to them
179            aggregate.group = {
180                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
181            }
182
183            intermediate: t.Dict[str | exp.Expression, str] = {}
184            for k, v in aggregate.group.items():
185                intermediate[v] = k
186                if isinstance(v, exp.Column):
187                    intermediate[v.name] = k
188
189            for projection in projections:
190                for node, *_ in projection.walk():
191                    name = intermediate.get(node)
192                    if name:
193                        node.replace(exp.column(name, step.name))
194
195            if aggregate.condition:
196                for node, *_ in aggregate.condition.walk():
197                    name = intermediate.get(node) or intermediate.get(node.name)
198                    if name:
199                        node.replace(exp.column(name, step.name))
200
201            aggregate.add_dependency(step)
202            step = aggregate
203
204        order = expression.args.get("order")
205
206        if order:
207            if isinstance(step, Aggregate):
208                for ordered in order.expressions:
209                    if ordered.find(exp.AggFunc):
210                        operand_name = next_operand_name()
211                        extract_agg_operands(exp.alias_(ordered.this, operand_name, quoted=True))
212                        ordered.this.replace(exp.column(operand_name, quoted=True))
213
214                step.aggregations = list(aggregations)
215
216            sort = Sort()
217            sort.name = step.name
218            sort.key = order.expressions
219            sort.add_dependency(step)
220            step = sort
221
222        step.projections = projections
223
224        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
225            distinct = Aggregate()
226            distinct.source = step.name
227            distinct.name = step.name
228            distinct.group = {
229                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
230                for e in projections or expression.expressions
231            }
232            distinct.add_dependency(step)
233            step = distinct
234
235        limit = expression.args.get("limit")
236
237        if limit:
238            step.limit = int(limit.text("expression"))
239
240        return step
241
242    def __init__(self) -> None:
243        self.name: t.Optional[str] = None
244        self.dependencies: t.Set[Step] = set()
245        self.dependents: t.Set[Step] = set()
246        self.projections: t.Sequence[exp.Expression] = []
247        self.limit: float = math.inf
248        self.condition: t.Optional[exp.Expression] = None
249
250    def add_dependency(self, dependency: Step) -> None:
251        self.dependencies.add(dependency)
252        dependency.dependents.add(self)
253
254    def __repr__(self) -> str:
255        return self.to_s()
256
257    def to_s(self, level: int = 0) -> str:
258        indent = "  " * level
259        nested = f"{indent}    "
260
261        context = self._to_s(f"{nested}  ")
262
263        if context:
264            context = [f"{nested}Context:"] + context
265
266        lines = [
267            f"{indent}- {self.id}",
268            *context,
269            f"{nested}Projections:",
270        ]
271
272        for expression in self.projections:
273            lines.append(f"{nested}  - {expression.sql()}")
274
275        if self.condition:
276            lines.append(f"{nested}Condition: {self.condition.sql()}")
277
278        if self.limit is not math.inf:
279            lines.append(f"{nested}Limit: {self.limit}")
280
281        if self.dependencies:
282            lines.append(f"{nested}Dependencies:")
283            for dependency in self.dependencies:
284                lines.append("  " + dependency.to_s(level + 1))
285
286        return "\n".join(lines)
287
288    @property
289    def type_name(self) -> str:
290        return self.__class__.__name__
291
292    @property
293    def id(self) -> str:
294        name = self.name
295        name = f" {name}" if name else ""
296        return f"{self.type_name}:{name} ({id(self)})"
297
298    def _to_s(self, _indent: str) -> t.List[str]:
299        return []
@classmethod
def from_expression( cls, expression: sqlglot.expressions.Expression, ctes: Optional[Dict[str, sqlglot.planner.Step]] = None) -> sqlglot.planner.Step:
 45    @classmethod
 46    def from_expression(
 47        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
 48    ) -> Step:
 49        """
 50        Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
 51        Note: the expression's tables and subqueries must be aliased for this method to work. For
 52        example, given the following expression:
 53
 54        SELECT
 55          x.a,
 56          SUM(x.b)
 57        FROM x AS x
 58        JOIN y AS y
 59          ON x.a = y.a
 60        GROUP BY x.a
 61
 62        the following DAG is produced (the expression IDs might differ per execution):
 63
 64        - Aggregate: x (4347984624)
 65            Context:
 66              Aggregations:
 67                - SUM(x.b)
 68              Group:
 69                - x.a
 70            Projections:
 71              - x.a
 72              - "x".""
 73            Dependencies:
 74            - Join: x (4347985296)
 75              Context:
 76                y:
 77                On: x.a = y.a
 78              Projections:
 79              Dependencies:
 80              - Scan: x (4347983136)
 81                Context:
 82                  Source: x AS x
 83                Projections:
 84              - Scan: y (4343416624)
 85                Context:
 86                  Source: y AS y
 87                Projections:
 88
 89        Args:
 90            expression: the expression to build the DAG from.
 91            ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
 92
 93        Returns:
 94            A Step DAG corresponding to `expression`.
 95        """
 96        ctes = ctes or {}
 97        expression = expression.unnest()
 98        with_ = expression.args.get("with")
 99
100        # CTEs break the mold of scope and introduce themselves to all in the context.
101        if with_:
102            ctes = ctes.copy()
103            for cte in with_.expressions:
104                step = Step.from_expression(cte.this, ctes)
105                step.name = cte.alias
106                ctes[step.name] = step  # type: ignore
107
108        from_ = expression.args.get("from")
109
110        if isinstance(expression, exp.Select) and from_:
111            step = Scan.from_expression(from_.this, ctes)
112        elif isinstance(expression, exp.Union):
113            step = SetOperation.from_expression(expression, ctes)
114        else:
115            step = Scan()
116
117        joins = expression.args.get("joins")
118
119        if joins:
120            join = Join.from_joins(joins, ctes)
121            join.name = step.name
122            join.add_dependency(step)
123            step = join
124
125        projections = []  # final selects in this chain of steps representing a select
126        operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
127        aggregations = set()
128        next_operand_name = name_sequence("_a_")
129
130        def extract_agg_operands(expression):
131            agg_funcs = tuple(expression.find_all(exp.AggFunc))
132            if agg_funcs:
133                aggregations.add(expression)
134
135            for agg in agg_funcs:
136                for operand in agg.unnest_operands():
137                    if isinstance(operand, exp.Column):
138                        continue
139                    if operand not in operands:
140                        operands[operand] = next_operand_name()
141
142                    operand.replace(exp.column(operands[operand], quoted=True))
143
144            return bool(agg_funcs)
145
146        for e in expression.expressions:
147            if e.find(exp.AggFunc):
148                projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
149                extract_agg_operands(e)
150            else:
151                projections.append(e)
152
153        where = expression.args.get("where")
154
155        if where:
156            step.condition = where.this
157
158        group = expression.args.get("group")
159
160        if group or aggregations:
161            aggregate = Aggregate()
162            aggregate.source = step.name
163            aggregate.name = step.name
164
165            having = expression.args.get("having")
166
167            if having:
168                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
169                    aggregate.condition = exp.column("_h", step.name, quoted=True)
170                else:
171                    aggregate.condition = having.this
172
173            aggregate.operands = tuple(
174                alias(operand, alias_) for operand, alias_ in operands.items()
175            )
176            aggregate.aggregations = list(aggregations)
177
178            # give aggregates names and replace projections with references to them
179            aggregate.group = {
180                f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
181            }
182
183            intermediate: t.Dict[str | exp.Expression, str] = {}
184            for k, v in aggregate.group.items():
185                intermediate[v] = k
186                if isinstance(v, exp.Column):
187                    intermediate[v.name] = k
188
189            for projection in projections:
190                for node, *_ in projection.walk():
191                    name = intermediate.get(node)
192                    if name:
193                        node.replace(exp.column(name, step.name))
194
195            if aggregate.condition:
196                for node, *_ in aggregate.condition.walk():
197                    name = intermediate.get(node) or intermediate.get(node.name)
198                    if name:
199                        node.replace(exp.column(name, step.name))
200
201            aggregate.add_dependency(step)
202            step = aggregate
203
204        order = expression.args.get("order")
205
206        if order:
207            if isinstance(step, Aggregate):
208                for ordered in order.expressions:
209                    if ordered.find(exp.AggFunc):
210                        operand_name = next_operand_name()
211                        extract_agg_operands(exp.alias_(ordered.this, operand_name, quoted=True))
212                        ordered.this.replace(exp.column(operand_name, quoted=True))
213
214                step.aggregations = list(aggregations)
215
216            sort = Sort()
217            sort.name = step.name
218            sort.key = order.expressions
219            sort.add_dependency(step)
220            step = sort
221
222        step.projections = projections
223
224        if isinstance(expression, exp.Select) and expression.args.get("distinct"):
225            distinct = Aggregate()
226            distinct.source = step.name
227            distinct.name = step.name
228            distinct.group = {
229                e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name)
230                for e in projections or expression.expressions
231            }
232            distinct.add_dependency(step)
233            step = distinct
234
235        limit = expression.args.get("limit")
236
237        if limit:
238            step.limit = int(limit.text("expression"))
239
240        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

- Aggregate: x (4347984624)
    Context:
      Aggregations:
        - SUM(x.b)
      Group:
        - x.a
    Projections:
      - x.a
      - "x".""
    Dependencies:
    - Join: x (4347985296)
      Context:
        y:
        On: x.a = y.a
      Projections:
      Dependencies:
      - Scan: x (4347983136)
        Context:
          Source: x AS x
        Projections:
      - Scan: y (4343416624)
        Context:
          Source: y AS y
        Projections:
Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

name: Optional[str]
dependencies: Set[sqlglot.planner.Step]
dependents: Set[sqlglot.planner.Step]
projections: Sequence[sqlglot.expressions.Expression]
limit: float
condition: Optional[sqlglot.expressions.Expression]
def add_dependency(self, dependency: sqlglot.planner.Step) -> None:
250    def add_dependency(self, dependency: Step) -> None:
251        self.dependencies.add(dependency)
252        dependency.dependents.add(self)
def to_s(self, level: int = 0) -> str:
257    def to_s(self, level: int = 0) -> str:
258        indent = "  " * level
259        nested = f"{indent}    "
260
261        context = self._to_s(f"{nested}  ")
262
263        if context:
264            context = [f"{nested}Context:"] + context
265
266        lines = [
267            f"{indent}- {self.id}",
268            *context,
269            f"{nested}Projections:",
270        ]
271
272        for expression in self.projections:
273            lines.append(f"{nested}  - {expression.sql()}")
274
275        if self.condition:
276            lines.append(f"{nested}Condition: {self.condition.sql()}")
277
278        if self.limit is not math.inf:
279            lines.append(f"{nested}Limit: {self.limit}")
280
281        if self.dependencies:
282            lines.append(f"{nested}Dependencies:")
283            for dependency in self.dependencies:
284                lines.append("  " + dependency.to_s(level + 1))
285
286        return "\n".join(lines)
type_name: str
id: str
class Scan(Step):
302class Scan(Step):
303    @classmethod
304    def from_expression(
305        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
306    ) -> Step:
307        table = expression
308        alias_ = expression.alias_or_name
309
310        if isinstance(expression, exp.Subquery):
311            table = expression.this
312            step = Step.from_expression(table, ctes)
313            step.name = alias_
314            return step
315
316        step = Scan()
317        step.name = alias_
318        step.source = expression
319        if ctes and table.name in ctes:
320            step.add_dependency(ctes[table.name])
321
322        return step
323
324    def __init__(self) -> None:
325        super().__init__()
326        self.source: t.Optional[exp.Expression] = None
327
328    def _to_s(self, indent: str) -> t.List[str]:
329        return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"]  # type: ignore
@classmethod
def from_expression(cls, expression: sqlglot.expressions.Expression, ctes: Optional[Dict[str, sqlglot.planner.Step]] = None) -> sqlglot.planner.Step:
303    @classmethod
304    def from_expression(
305        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
306    ) -> Step:
307        table = expression
308        alias_ = expression.alias_or_name
309
310        if isinstance(expression, exp.Subquery):
311            table = expression.this
312            step = Step.from_expression(table, ctes)
313            step.name = alias_
314            return step
315
316        step = Scan()
317        step.name = alias_
318        step.source = expression
319        if ctes and table.name in ctes:
320            step.add_dependency(ctes[table.name])
321
322        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

  - Aggregate: x (4347984624)
      Context:
        Aggregations:
          - SUM(x.b)
        Group:
          - x.a
      Projections:
        - x.a
        - "x".""
      Dependencies:
      - Join: x (4347985296)
        Context:
          y:
          On: x.a = y.a
        Projections:
        Dependencies:
        - Scan: x (4347983136)
          Context:
            Source: x AS x
          Projections:
        - Scan: y (4343416624)
          Context:
            Source: y AS y
          Projections:
Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

source: Optional[sqlglot.expressions.Expression]
class Join(Step):
332class Join(Step):
333    @classmethod
334    def from_joins(
335        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
336    ) -> Step:
337        step = Join()
338
339        for join in joins:
340            source_key, join_key, condition = join_condition(join)
341            step.joins[join.alias_or_name] = {
342                "side": join.side,  # type: ignore
343                "join_key": join_key,
344                "source_key": source_key,
345                "condition": condition,
346            }
347
348            step.add_dependency(Scan.from_expression(join.this, ctes))
349
350        return step
351
352    def __init__(self) -> None:
353        super().__init__()
354        self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}
355
356    def _to_s(self, indent: str) -> t.List[str]:
357        lines = []
358        for name, join in self.joins.items():
359            lines.append(f"{indent}{name}: {join['side']}")
360            if join.get("condition"):
361                lines.append(f"{indent}On: {join['condition'].sql()}")  # type: ignore
362        return lines
@classmethod
def from_joins(cls, joins: Iterable[sqlglot.expressions.Join], ctes: Optional[Dict[str, sqlglot.planner.Step]] = None) -> sqlglot.planner.Step:
333    @classmethod
334    def from_joins(
335        cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
336    ) -> Step:
337        step = Join()
338
339        for join in joins:
340            source_key, join_key, condition = join_condition(join)
341            step.joins[join.alias_or_name] = {
342                "side": join.side,  # type: ignore
343                "join_key": join_key,
344                "source_key": source_key,
345                "condition": condition,
346            }
347
348            step.add_dependency(Scan.from_expression(join.this, ctes))
349
350        return step
joins: Dict[str, Dict[str, Union[List[str], sqlglot.expressions.Expression]]]
class Aggregate(Step):
365class Aggregate(Step):
366    def __init__(self) -> None:
367        super().__init__()
368        self.aggregations: t.List[exp.Expression] = []
369        self.operands: t.Tuple[exp.Expression, ...] = ()
370        self.group: t.Dict[str, exp.Expression] = {}
371        self.source: t.Optional[str] = None
372
373    def _to_s(self, indent: str) -> t.List[str]:
374        lines = [f"{indent}Aggregations:"]
375
376        for expression in self.aggregations:
377            lines.append(f"{indent}  - {expression.sql()}")
378
379        if self.group:
380            lines.append(f"{indent}Group:")
381            for expression in self.group.values():
382                lines.append(f"{indent}  - {expression.sql()}")
383        if self.condition:
384            lines.append(f"{indent}Having:")
385            lines.append(f"{indent}  - {self.condition.sql()}")
386        if self.operands:
387            lines.append(f"{indent}Operands:")
388            for expression in self.operands:
389                lines.append(f"{indent}  - {expression.sql()}")
390
391        return lines
aggregations: List[sqlglot.expressions.Expression]
operands: Tuple[sqlglot.expressions.Expression, ...]
group: Dict[str, sqlglot.expressions.Expression]
source: Optional[str]
class Sort(Step):
394class Sort(Step):
395    def __init__(self) -> None:
396        super().__init__()
397        self.key = None
398
399    def _to_s(self, indent: str) -> t.List[str]:
400        lines = [f"{indent}Key:"]
401
402        for expression in self.key:  # type: ignore
403            lines.append(f"{indent}  - {expression.sql()}")
404
405        return lines
key
class SetOperation(Step):
408class SetOperation(Step):
409    def __init__(
410        self,
411        op: t.Type[exp.Expression],
412        left: str | None,
413        right: str | None,
414        distinct: bool = False,
415    ) -> None:
416        super().__init__()
417        self.op = op
418        self.left = left
419        self.right = right
420        self.distinct = distinct
421
422    @classmethod
423    def from_expression(
424        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
425    ) -> Step:
426        assert isinstance(expression, exp.Union)
427        left = Step.from_expression(expression.left, ctes)
428        right = Step.from_expression(expression.right, ctes)
429        step = cls(
430            op=expression.__class__,
431            left=left.name,
432            right=right.name,
433            distinct=bool(expression.args.get("distinct")),
434        )
435        step.add_dependency(left)
436        step.add_dependency(right)
437        return step
438
439    def _to_s(self, indent: str) -> t.List[str]:
440        lines = []
441        if self.distinct:
442            lines.append(f"{indent}Distinct: {self.distinct}")
443        return lines
444
445    @property
446    def type_name(self) -> str:
447        return self.op.__name__
SetOperation(op: Type[sqlglot.expressions.Expression], left: str | None, right: str | None, distinct: bool = False)
409    def __init__(
410        self,
411        op: t.Type[exp.Expression],
412        left: str | None,
413        right: str | None,
414        distinct: bool = False,
415    ) -> None:
416        super().__init__()
417        self.op = op
418        self.left = left
419        self.right = right
420        self.distinct = distinct
op
left
right
distinct
@classmethod
def from_expression(cls, expression: sqlglot.expressions.Expression, ctes: Optional[Dict[str, sqlglot.planner.Step]] = None) -> sqlglot.planner.Step:
422    @classmethod
423    def from_expression(
424        cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
425    ) -> Step:
426        assert isinstance(expression, exp.Union)
427        left = Step.from_expression(expression.left, ctes)
428        right = Step.from_expression(expression.right, ctes)
429        step = cls(
430            op=expression.__class__,
431            left=left.name,
432            right=right.name,
433            distinct=bool(expression.args.get("distinct")),
434        )
435        step.add_dependency(left)
436        step.add_dependency(right)
437        return step

Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. Note: the expression's tables and subqueries must be aliased for this method to work. For example, given the following expression:

SELECT x.a, SUM(x.b) FROM x AS x JOIN y AS y ON x.a = y.a GROUP BY x.a

the following DAG is produced (the expression IDs might differ per execution):

  - Aggregate: x (4347984624)
      Context:
        Aggregations:
          - SUM(x.b)
        Group:
          - x.a
      Projections:
        - x.a
        - "x".""
      Dependencies:
      - Join: x (4347985296)
        Context:
          y:
          On: x.a = y.a
        Projections:
        Dependencies:
        - Scan: x (4347983136)
          Context:
            Source: x AS x
          Projections:
        - Scan: y (4343416624)
          Context:
            Source: y AS y
          Projections:
Arguments:
  • expression: the expression to build the DAG from.
  • ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
Returns:

A Step DAG corresponding to expression.

type_name: str