Edit on GitHub

sqlglot.optimizer.pushdown_predicates

  1from sqlglot import exp
  2from sqlglot.optimizer.normalize import normalized
  3from sqlglot.optimizer.scope import build_scope, find_in_scope
  4from sqlglot.optimizer.simplify import simplify
  5
  6
  7def pushdown_predicates(expression, dialect=None):
  8    """
  9    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
 10
 11    Example:
 12        >>> import sqlglot
 13        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
 14        >>> expression = sqlglot.parse_one(sql)
 15        >>> pushdown_predicates(expression).sql()
 16        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
 17
 18    Args:
 19        expression (sqlglot.Expression): expression to optimize
 20    Returns:
 21        sqlglot.Expression: optimized expression
 22    """
 23    root = build_scope(expression)
 24
 25    if root:
 26        scope_ref_count = root.ref_count()
 27
 28        for scope in reversed(list(root.traverse())):
 29            select = scope.expression
 30            where = select.args.get("where")
 31            if where:
 32                selected_sources = scope.selected_sources
 33                # a right join can only push down to itself and not the source FROM table
 34                for k, (node, source) in selected_sources.items():
 35                    parent = node.find_ancestor(exp.Join, exp.From)
 36                    if isinstance(parent, exp.Join) and parent.side == "RIGHT":
 37                        selected_sources = {k: (node, source)}
 38                        break
 39                pushdown(where.this, selected_sources, scope_ref_count, dialect)
 40
 41            # joins should only pushdown into itself, not to other joins
 42            # so we limit the selected sources to only itself
 43            for join in select.args.get("joins") or []:
 44                name = join.alias_or_name
 45                if name in scope.selected_sources:
 46                    pushdown(
 47                        join.args.get("on"),
 48                        {name: scope.selected_sources[name]},
 49                        scope_ref_count,
 50                        dialect,
 51                    )
 52
 53    return expression
 54
 55
 56def pushdown(condition, sources, scope_ref_count, dialect):
 57    if not condition:
 58        return
 59
 60    condition = condition.replace(simplify(condition, dialect=dialect))
 61    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 62
 63    predicates = list(
 64        condition.flatten()
 65        if isinstance(condition, exp.And if cnf_like else exp.Or)
 66        else [condition]
 67    )
 68
 69    if cnf_like:
 70        pushdown_cnf(predicates, sources, scope_ref_count)
 71    else:
 72        pushdown_dnf(predicates, sources, scope_ref_count)
 73
 74
 75def pushdown_cnf(predicates, scope, scope_ref_count):
 76    """
 77    If the predicates are in CNF like form, we can simply replace each block in the parent.
 78    """
 79    for predicate in predicates:
 80        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
 81            if isinstance(node, exp.Join):
 82                predicate.replace(exp.true())
 83                node.on(predicate, copy=False)
 84                break
 85            if isinstance(node, exp.Select):
 86                predicate.replace(exp.true())
 87                inner_predicate = replace_aliases(node, predicate)
 88                if find_in_scope(inner_predicate, exp.AggFunc):
 89                    node.having(inner_predicate, copy=False)
 90                else:
 91                    node.where(inner_predicate, copy=False)
 92
 93
 94def pushdown_dnf(predicates, scope, scope_ref_count):
 95    """
 96    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
 97    Additionally, we can't remove predicates from their original form.
 98    """
 99    # find all the tables that can be pushdown too
100    # these are tables that are referenced in all blocks of a DNF
101    # (a.x AND b.x) OR (a.y AND c.y)
102    # only table a can be push down
103    pushdown_tables = set()
104
105    for a in predicates:
106        a_tables = exp.column_table_names(a)
107
108        for b in predicates:
109            a_tables &= exp.column_table_names(b)
110
111        pushdown_tables.update(a_tables)
112
113    conditions = {}
114
115    # pushdown all predicates to their respective nodes
116    for table in sorted(pushdown_tables):
117        for predicate in predicates:
118            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
119
120            if table not in nodes:
121                continue
122
123            conditions[table] = (
124                exp.or_(conditions[table], predicate) if table in conditions else predicate
125            )
126
127        for name, node in nodes.items():
128            if name not in conditions:
129                continue
130
131            predicate = conditions[name]
132
133            if isinstance(node, exp.Join):
134                node.on(predicate, copy=False)
135            elif isinstance(node, exp.Select):
136                inner_predicate = replace_aliases(node, predicate)
137                if find_in_scope(inner_predicate, exp.AggFunc):
138                    node.having(inner_predicate, copy=False)
139                else:
140                    node.where(inner_predicate, copy=False)
141
142
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes a predicate can be pushed into.

    Args:
        predicate: the predicate expression to place
        sources: mapping of source name to a (node, source) pair
        scope_ref_count: reference counts, looked up by `id(source)`

    Returns:
        dict: table name -> the `exp.Join` or `exp.Select` node that can take the
        predicate. The dict is empty as soon as one referenced table resolves to
        a source under a recursive WITH, or to a join that has a side other
        than RIGHT.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    # True when the nearest Join/Where ancestor of the predicate is a Where
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # only joins without a side, or RIGHT joins, are accepted
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select.find(exp.Window) for select in node.selects
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes
181
182
def replace_aliases(source, predicate):
    """
    Rewrite a predicate's column references in terms of a select's projections.

    Args:
        source: a select-like expression exposing `.selects`
        predicate: the predicate whose columns should be rewritten

    Returns:
        The result of `predicate.transform(...)`, in which every `exp.Column`
        whose name matches a projection of `source` has been replaced by a copy
        of that projection's expression.
    """
    aliases = {}

    for select in source.selects:
        if isinstance(select, exp.Alias):
            # `expr AS name` -> name maps to the aliased expression
            aliases[select.alias] = select.this
        else:
            # unaliased projection -> its own name maps to itself
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            # copy so the projection in `source` is not shared with the predicate
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)
def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: passed through to `simplify` when each condition is simplified
            before being pushed down (defaults to None)
    Returns:
        sqlglot.Expression: optimized expression (the same `expression` object that
            was passed in)
    """
    root = build_scope(expression)

    # nothing to traverse -> hand the input back untouched
    if not root:
        return expression

    scope_ref_count = root.ref_count()

    # walk the scopes in the reverse of the traversal order
    for scope in reversed(list(root.traverse())):
        select = scope.expression
        where = select.args.get("where")
        if where:
            selected_sources = scope.selected_sources
            # a right join can only push down to itself and not the source FROM table
            for k, (node, source) in selected_sources.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                    selected_sources = {k: (node, source)}
                    break
            pushdown(where.this, selected_sources, scope_ref_count, dialect)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        for join in select.args.get("joins") or []:
            name = join.alias_or_name
            if name in scope.selected_sources:
                pushdown(
                    join.args.get("on"),
                    {name: scope.selected_sources[name]},
                    scope_ref_count,
                    dialect,
                )

    return expression

Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
Returns:

sqlglot.Expression: optimized expression

def pushdown(condition, sources, scope_ref_count, dialect):
    """
    Simplify one WHERE/ON condition and hand its predicates to the CNF or DNF routine.

    Args:
        condition: condition expression to push down; a falsy value makes this a no-op
        sources: mapping of source name to a (node, source) pair that predicates
            may be pushed into
        scope_ref_count: reference counts, looked up by `id(source)` downstream
        dialect: passed through to `simplify`
    """
    if not condition:
        return

    # swap the simplified form into the tree and keep working on whatever
    # `replace` hands back
    condition = condition.replace(simplify(condition, dialect=dialect))

    # anything that is not strictly DNF is handled by the CNF routine
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    # split on AND for CNF-like input, on OR for DNF input; a condition of any
    # other type is treated as a single predicate
    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if cnf_like:
        pushdown_cnf(predicates, sources, scope_ref_count)
    else:
        pushdown_dnf(predicates, sources, scope_ref_count)
def pushdown_cnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in CNF-like form, we can simply replace each block in the parent.

    Args:
        predicates: the individual predicates of the condition
        scope: mapping of source name to (node, source); passed to
            `nodes_for_predicate` as its `sources` argument
        scope_ref_count: reference counts, passed to `nodes_for_predicate`
    """
    for predicate in predicates:
        for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
            if isinstance(node, exp.Join):
                # leave TRUE behind in the original position and attach the
                # predicate itself to the join's ON clause
                predicate.replace(exp.true())
                node.on(predicate, copy=False)
                break
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                # rewrite references to the select's output names in terms of
                # the expressions that produce them
                inner_predicate = replace_aliases(node, predicate)
                # a predicate containing an aggregate goes to HAVING, not WHERE
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)

If the predicates are in CNF-like form, we can simply replace each block in the parent.

def pushdown_dnf(predicates, scope, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates: the OR-ed blocks of the condition
        scope: mapping of source name to (node, source); passed to
            `nodes_for_predicate` as its `sources` argument
        scope_ref_count: reference counts, passed to `nodes_for_predicate`
    """
    # find all the tables that can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    #
    # One intersection over all blocks gives the same set as intersecting every
    # block with every other block, while collecting table names once per block.
    table_sets = [exp.column_table_names(predicate) for predicate in predicates]
    pushdown_tables = (
        set(table_sets[0]).intersection(*table_sets[1:]) if table_sets else set()
    )

    conditions = {}

    # pushdown all predicates to their respective nodes
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope, scope_ref_count)

            if table not in nodes:
                continue

            # OR together every block that can reach this table
            conditions[table] = (
                exp.or_(conditions[table], predicate) if table in conditions else predicate
            )

        # NOTE(review): `nodes` here is the mapping computed for the last
        # predicate of the loop above
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                inner_predicate = replace_aliases(node, predicate)
                # a predicate containing an aggregate goes to HAVING, not WHERE
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)

If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form.

def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Find the nodes a predicate can be pushed into.

    Args:
        predicate: the predicate expression to place
        sources: mapping of source name to a (node, source) pair
        scope_ref_count: reference counts, looked up by `id(source)`

    Returns:
        dict: table name -> the `exp.Join` or `exp.Select` node that can take the
        predicate. The dict is empty as soon as one referenced table resolves to
        a source under a recursive WITH, or to a join that has a side other
        than RIGHT.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    # True when the nearest Join/Where ancestor of the predicate is a Where
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # only joins without a side, or RIGHT joins, are accepted
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select.find(exp.Window) for select in node.selects
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes
def replace_aliases(source, predicate):
    """
    Rewrite a predicate's column references in terms of a select's projections.

    Args:
        source: a select-like expression exposing `.selects`
        predicate: the predicate whose columns should be rewritten

    Returns:
        The result of `predicate.transform(...)`, in which every `exp.Column`
        whose name matches a projection of `source` has been replaced by a copy
        of that projection's expression.
    """
    aliases = {}

    for select in source.selects:
        if isinstance(select, exp.Alias):
            # `expr AS name` -> name maps to the aliased expression
            aliases[select.alias] = select.this
        else:
            # unaliased projection -> its own name maps to itself
            aliases[select.name] = select

    def _replace_alias(column):
        if isinstance(column, exp.Column) and column.name in aliases:
            # copy so the projection in `source` is not shared with the predicate
            return aliases[column.name].copy()
        return column

    return predicate.transform(_replace_alias)