Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import cached_generator
  9from sqlglot.helper import first, while_changing
 10
 11# Final means that an expression should not be simplified
 12FINAL = "final"
 13
 14
 15def simplify(expression):
 16    """
 17    Rewrite sqlglot AST to simplify expressions.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 22        >>> simplify(expression).sql()
 23        'TRUE'
 24
 25    Args:
 26        expression (sqlglot.Expression): expression to simplify
 27    Returns:
 28        sqlglot.Expression: simplified expression
 29    """
 30
 31    generate = cached_generator()
 32
 33    # group by expressions cannot be simplified, for example
 34    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 35    # the projection must exactly match the group by key
 36    for group in expression.find_all(exp.Group):
 37        select = group.parent
 38        groups = set(group.expressions)
 39        group.meta[FINAL] = True
 40
 41        for e in select.selects:
 42            for node, *_ in e.walk():
 43                if node in groups:
 44                    e.meta[FINAL] = True
 45                    break
 46
 47        having = select.args.get("having")
 48        if having:
 49            for node, *_ in having.walk():
 50                if node in groups:
 51                    having.meta[FINAL] = True
 52                    break
 53
 54    def _simplify(expression, root=True):
 55        if expression.meta.get(FINAL):
 56            return expression
 57
 58        # Pre-order transformations
 59        node = expression
 60        node = rewrite_between(node)
 61        node = uniq_sort(node, generate, root)
 62        node = absorb_and_eliminate(node, root)
 63        node = simplify_concat(node)
 64
 65        exp.replace_children(node, lambda e: _simplify(e, False))
 66
 67        # Post-order transformations
 68        node = simplify_not(node)
 69        node = flatten(node)
 70        node = simplify_connectors(node, root)
 71        node = remove_compliments(node, root)
 72        node = simplify_coalesce(node)
 73        node.parent = expression.parent
 74        node = simplify_literals(node, root)
 75        node = simplify_parens(node)
 76
 77        if root:
 78            expression.replace(node)
 79
 80        return node
 81
 82    expression = while_changing(expression, _simplify)
 83    remove_where_true(expression)
 84    return expression
 85
 86
 87def rewrite_between(expression: exp.Expression) -> exp.Expression:
 88    """Rewrite x between y and z to x >= y AND x <= z.
 89
 90    This is done because comparison simplification is only done on lt/lte/gt/gte.
 91    """
 92    if isinstance(expression, exp.Between):
 93        return exp.and_(
 94            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 95            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 96            copy=False,
 97        )
 98    return expression
 99
100
def simplify_not(expression):
    """
    Simplify a negation.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    Constants and double negation:
        NOT NULL -> NULL
        NOT <always true> -> FALSE
        NOT FALSE -> TRUE
        NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # Negating a connector swaps it for its dual and negates both operands.
            dual = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return dual(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this

    return expression
135
136
def flatten(expression):
    """
    Drop parentheses around an operand that is the same kind of connector as its parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        connector_type = type(expression)
        for operand in expression.args.values():
            unwrapped = operand.unnest()
            if isinstance(unwrapped, connector_type):
                operand.replace(unwrapped)
    return expression
148
149
def simplify_connectors(expression, root=True):
    """
    Fold pairs of AND/OR operands that involve constants or comparable comparisons.

    Expressions that are not connectors are returned unchanged.
    """

    def _merge_and(connector, left, right):
        if is_false(left) or is_false(right):
            return exp.false()
        if is_null(left) or is_null(right):
            return exp.null()
        if always_true(left):
            return exp.true() if always_true(right) else right
        if always_true(right):
            return left
        return _simplify_comparison(connector, left, right)

    def _merge_or(connector, left, right):
        if always_true(left) or always_true(right):
            return exp.true()

        left_false, right_false = is_false(left), is_false(right)
        if left_false and right_false:
            return exp.false()
        # Both sides are NULL/FALSE and at least one of them is NULL.
        if (left_false or is_null(left)) and (right_false or is_null(right)):
            return exp.null()
        if left_false:
            return right
        if right_false:
            return left
        return _simplify_comparison(connector, left, right, or_=True)

    def _merge(connector, left, right):
        # Returns the merged operand, or None when the pair cannot be merged.
        if left == right:
            return left
        if isinstance(connector, exp.And):
            return _merge_and(connector, left, right)
        if isinstance(connector, exp.Or):
            return _merge_or(connector, left, right)
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _merge, root)
186
187
188LT_LTE = (exp.LT, exp.LTE)
189GT_GTE = (exp.GT, exp.GTE)
190
191COMPARISONS = (
192    *LT_LTE,
193    *GT_GTE,
194    exp.EQ,
195    exp.NEQ,
196    exp.Is,
197)
198
199INVERSE_COMPARISONS = {
200    exp.LT: exp.GT,
201    exp.GT: exp.LT,
202    exp.LTE: exp.GTE,
203    exp.GTE: exp.LTE,
204}
205
206
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons of the same column against number or string literals.

    `left` and `right` are operands of the connector `expression`; `or_` selects OR
    semantics instead of AND. Returns the merged expression, `expression` itself when
    neither side has a non-column operand left, or None when nothing can be merged.
    """
    if not (isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS)):
        return None

    left_this, left_other = left.args.values()
    right_this, right_other = right.args.values()

    left_operands = {left_this, left_other}
    right_operands = {right_this, right_other}

    shared_columns = {
        operand for operand in left_operands & right_operands if isinstance(operand, exp.Column)
    }
    if not shared_columns:
        return None

    try:
        left_value = first(left_operands - shared_columns)
        right_value = first(right_operands - shared_columns)
    except StopIteration:
        return expression

    # make sure the comparison is always of the form x > 1 instead of 1 < x
    if type(left) in INVERSE_COMPARISONS and left_value == left_this:
        left = INVERSE_COMPARISONS[type(left)](this=left_other, expression=left_this)
    if type(right) in INVERSE_COMPARISONS and right_value == right_this:
        right = INVERSE_COMPARISONS[type(right)](this=right_other, expression=right_this)

    # Only number/number and string/string pairs are compared.
    if left_value.is_number and right_value.is_number:
        left_value = float(left_value.name)
        right_value = float(right_value.name)
    elif left_value.is_string and right_value.is_string:
        left_value = left_value.name
        right_value = right_value.name
    else:
        return None

    pairs = ((left, left_value), (right, right_value))
    for (a, a_value), (b, b_value) in itertools.permutations(pairs):
        if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
            keep_left = a_value > b_value if or_ else a_value <= b_value
            return left if keep_left else right
        if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
            keep_left = a_value < b_value if or_ else a_value >= b_value
            return left if keep_left else right

        # we can't ever shortcut to true because the column could be null
        if or_:
            continue

        if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
            if a_value <= b_value:
                return exp.false()
        elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
            if a_value >= b_value:
                return exp.false()
        elif isinstance(a, exp.EQ):
            if isinstance(b, exp.LT):
                contradiction = a_value >= b_value
            elif isinstance(b, exp.LTE):
                contradiction = a_value > b_value
            elif isinstance(b, exp.GT):
                contradiction = a_value <= b_value
            elif isinstance(b, exp.GTE):
                contradiction = a_value < b_value
            elif isinstance(b, exp.NEQ):
                contradiction = a_value == b_value
            else:
                continue
            return exp.false() if contradiction else a

    return None
266
267
def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    # Nested connectors that share their parent's type are skipped unless this is the root call.
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
282
283
def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are deduplicated and ordered by the string `generate` returns for them.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keyed by generated SQL; a repeated key keeps its first position and last operand.
    by_sql = {generate(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # check if the operands are already sorted, if not sort them
    # A AND C AND B -> A AND B AND C
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        return build(*(by_sql[sql] for sql in sorted(keys)), copy=False)

    # we didn't have to sort but maybe we need to dedup
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression
308
309
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are done in place by replacing operands with TRUE/FALSE (or with the
    surviving operand); the same `expression` object is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # `kind` is the opposite connector: the type of the nested operands we look into.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    nested_is_and = kind == exp.And

    def _nested_neutral():
        # TRUE inside a nested AND, FALSE inside a nested OR (a fresh node per call).
        return exp.true() if nested_is_and else exp.false()

    def _outer_neutral():
        # FALSE when the outer connector is OR, TRUE when it is AND (a fresh node per call).
        return exp.false() if nested_is_and else exp.true()

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(_nested_neutral())
        elif is_complement(b, ab):
            ab.replace(_nested_neutral())
        elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
            a.replace(_outer_neutral())
        elif isinstance(b, kind):
            # eliminate
            rhs = b.unnest_operands()
            ba, bb = rhs

            if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                a.replace(aa)
                b.replace(aa)
            elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                a.replace(ab)
                b.replace(ab)

    return expression
348
349
def simplify_literals(expression, root=True):
    """
    Fold literal operands of non-connector binary expressions and negated number literals.

    Binary expressions are delegated to `_flat_simplify` with `_simplify_binary`;
    a negated number literal is turned into a literal with the sign flipped.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Strip an existing leading minus, otherwise prepend one.
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    return expression
362
363
def _simplify_binary(expression, a, b):
    """
    Try to fold the operands `a` and `b` of the binary `expression` into one node.

    Handles IS [NOT] NULL on literals/NULL, NULL propagation, arithmetic and comparisons
    on number literals, comparisons on string literals, and date +/- interval arithmetic.

    Returns:
        The folded expression, or None when the pair cannot be folded.
    """
    if isinstance(expression, exp.Is):
        negated = isinstance(b, exp.Not)
        target = b.this if negated else b

        if is_null(target):
            if isinstance(a, exp.Literal):
                return exp.true() if negated else exp.false()
            if is_null(a):
                return exp.false() if negated else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(num_a, int) and isinstance(num_b, int):
                return None
            # Dividing a Decimal by zero raises (DivisionByZero / InvalidOperation);
            # leave the expression alone so the engine decides what x / 0 means.
            if num_b == 0:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        date, delta = extract_date(a), extract_interval(b)
        if date and delta:
            if isinstance(expression, exp.Add):
                return date_literal(date + delta)
            if isinstance(expression, exp.Sub):
                return date_literal(date - delta)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        delta, date = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if delta and date and isinstance(expression, exp.Add):
            return date_literal(delta + date)

    return None
422
423
def simplify_parens(expression):
    """
    Remove a pair of parentheses when the checks below deem it redundant.

    Returns the wrapped expression when the parentheses can go, otherwise `expression`.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    # Parenthesized SELECTs always keep their parentheses.
    if isinstance(this, exp.Select):
        return expression

    redundant = (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(this, exp.Predicate)
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    )
    return this if redundant else expression
442
443
444CONSTANTS = (
445    exp.Literal,
446    exp.Boolean,
447    exp.Null,
448)
449
450
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    COALESCE(x) -> x
    COALESCE(x, <const>, ...) <op> <const2>
        -> (NOT x IS NULL AND x <op> <const2>) OR (x IS NULL AND <const> <op> <const2>)

    The second rewrite applies to comparisons where one side is a COALESCE and the other
    side is a constant; arguments from the first constant onwards are dropped from the
    COALESCE. The operand order of the original comparison is preserved, so
    `<const2> <op> COALESCE(...)` substitutes into `<const2> <op> <const>`.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Substitute the constant for the COALESCE on the side it came from, so asymmetric
    # comparisons (<, <=, >, >=) are not flipped.
    if coalesce_on_left:
        substituted = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        substituted = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                substituted,
                copy=False,
            ),
            copy=False,
        )
    )
507
508
509CONCATS = (exp.Concat, exp.DPipe)
510SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
511
512
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if all_strings:
            # Collapse a run of adjacent string literals into one literal.
            merged.append(exp.Literal.string("".join(part.name for part in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_type(expressions=merged)
530
531
532# CROSS joins result in an empty table if the right table is empty.
533# So we can only simplify certain types of joins to CROSS.
534# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
535JOINS = {
536    ("", ""),
537    ("", "INNER"),
538    ("RIGHT", ""),
539    ("RIGHT", "OUTER"),
540}
541
542
def remove_where_true(expression):
    """
    Drop WHERE clauses that are always true, and turn joins with an always-true ON
    condition into CROSS joins when their (side, kind) pair is listed in JOINS.

    The tree is modified in place; nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
557
558
def always_true(expression):
    """Return a truthy value for a boolean node holding a truthy value or for any literal."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal)
563
564
def is_complement(a, b):
    """Return whether `b` is a NOT node whose operand equals `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
567
568
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Boolean node (not a subclass) holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
571
572
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
575
576
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the plain Python values `a` and `b`.

    Returns a boolean literal node for =, IS, <>, >, >=, < and <=, or None for any
    other kind of expression.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None
    return boolean_literal(outcome)
591
592
def extract_date(cast):
    """
    Parse the value of a CAST to DATE or DATETIME as an ISO date/datetime.

    Returns a `datetime.date` or `datetime.datetime`, or None when the target type is
    neither of the two or the value is not in ISO format.
    """
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        target_type = cast.args["to"].this
        if target_type == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target_type == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
603
604
def extract_interval(interval):
    """
    Convert an interval node into a `dateutil` relativedelta.

    Reads the amount from `interval.name` and the unit from `interval.text("unit")`.

    Returns:
        A relativedelta for year/month/week/day units, or None when `dateutil` is not
        installed, the amount is not an integer, or the unit is not one of those four.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    # The amount may be a non-integer such as '1.5' or an identifier; such intervals
    # cannot be folded, so bail out instead of raising.
    try:
        n = int(interval.name)
    except ValueError:
        return None

    unit = interval.text("unit").lower()

    # Maps the supported units to the relativedelta keyword of the same meaning.
    keyword_by_unit = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }
    keyword = keyword_by_unit.get(unit)
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
623
624
def date_literal(date):
    """Build a CAST of the date's string form to DATETIME (for datetimes) or DATE."""
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
630
631
def boolean_literal(condition):
    """Return a TRUE node when `condition` is truthy, otherwise a FALSE node."""
    if condition:
        return exp.true()
    return exp.false()
634
635
636def _flat_simplify(expression, simplifier, root=True):
637    if root or not expression.same_parent:
638        operands = []
639        queue = deque(expression.flatten(unnest=False))
640        size = len(queue)
641
642        while queue:
643            a = queue.popleft()
644
645            for b in queue:
646                result = simplifier(expression, a, b)
647
648                if result:
649                    queue.remove(b)
650                    queue.appendleft(result)
651                    break
652            else:
653                operands.append(a)
654
655        if len(operands) < size:
656            return functools.reduce(
657                lambda a, b: expression.__class__(this=a, expression=b), operands
658            )
659    return expression
FINAL = 'final'
def simplify(expression):
16def simplify(expression):
17    """
18    Rewrite sqlglot AST to simplify expressions.
19
20    Example:
21        >>> import sqlglot
22        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
23        >>> simplify(expression).sql()
24        'TRUE'
25
26    Args:
27        expression (sqlglot.Expression): expression to simplify
28    Returns:
29        sqlglot.Expression: simplified expression
30    """
31
32    generate = cached_generator()
33
34    # group by expressions cannot be simplified, for example
35    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
36    # the projection must exactly match the group by key
37    for group in expression.find_all(exp.Group):
38        select = group.parent
39        groups = set(group.expressions)
40        group.meta[FINAL] = True
41
42        for e in select.selects:
43            for node, *_ in e.walk():
44                if node in groups:
45                    e.meta[FINAL] = True
46                    break
47
48        having = select.args.get("having")
49        if having:
50            for node, *_ in having.walk():
51                if node in groups:
52                    having.meta[FINAL] = True
53                    break
54
55    def _simplify(expression, root=True):
56        if expression.meta.get(FINAL):
57            return expression
58
59        # Pre-order transformations
60        node = expression
61        node = rewrite_between(node)
62        node = uniq_sort(node, generate, root)
63        node = absorb_and_eliminate(node, root)
64        node = simplify_concat(node)
65
66        exp.replace_children(node, lambda e: _simplify(e, False))
67
68        # Post-order transformations
69        node = simplify_not(node)
70        node = flatten(node)
71        node = simplify_connectors(node, root)
72        node = remove_compliments(node, root)
73        node = simplify_coalesce(node)
74        node.parent = expression.parent
75        node = simplify_literals(node, root)
76        node = simplify_parens(node)
77
78        if root:
79            expression.replace(node)
80
81        return node
82
83    expression = while_changing(expression, _simplify)
84    remove_where_true(expression)
85    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
88def rewrite_between(expression: exp.Expression) -> exp.Expression:
89    """Rewrite x between y and z to x >= y AND x <= z.
90
91    This is done because comparison simplification is only done on lt/lte/gt/gte.
92    """
93    if isinstance(expression, exp.Between):
94        return exp.and_(
95            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
96            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
97            copy=False,
98        )
99    return expression

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
102def simplify_not(expression):
103    """
104    Demorgan's Law
105    NOT (x OR y) -> NOT x AND NOT y
106    NOT (x AND y) -> NOT x OR NOT y
107    """
108    if isinstance(expression, exp.Not):
109        if is_null(expression.this):
110            return exp.null()
111        if isinstance(expression.this, exp.Paren):
112            condition = expression.this.unnest()
113            if isinstance(condition, exp.And):
114                return exp.or_(
115                    exp.not_(condition.left, copy=False),
116                    exp.not_(condition.right, copy=False),
117                    copy=False,
118                )
119            if isinstance(condition, exp.Or):
120                return exp.and_(
121                    exp.not_(condition.left, copy=False),
122                    exp.not_(condition.right, copy=False),
123                    copy=False,
124                )
125            if is_null(condition):
126                return exp.null()
127        if always_true(expression.this):
128            return exp.false()
129        if is_false(expression.this):
130            return exp.true()
131        if isinstance(expression.this, exp.Not):
132            # double negation
133            # NOT NOT x -> x
134            return expression.this.this
135    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
138def flatten(expression):
139    """
140    A AND (B AND C) -> A AND B AND C
141    A OR (B OR C) -> A OR B OR C
142    """
143    if isinstance(expression, exp.Connector):
144        for node in expression.args.values():
145            child = node.unnest()
146            if isinstance(child, expression.__class__):
147                node.replace(child)
148    return expression

Flatten nested connectors of the same type: A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
151def simplify_connectors(expression, root=True):
152    def _simplify_connectors(expression, left, right):
153        if left == right:
154            return left
155        if isinstance(expression, exp.And):
156            if is_false(left) or is_false(right):
157                return exp.false()
158            if is_null(left) or is_null(right):
159                return exp.null()
160            if always_true(left) and always_true(right):
161                return exp.true()
162            if always_true(left):
163                return right
164            if always_true(right):
165                return left
166            return _simplify_comparison(expression, left, right)
167        elif isinstance(expression, exp.Or):
168            if always_true(left) or always_true(right):
169                return exp.true()
170            if is_false(left) and is_false(right):
171                return exp.false()
172            if (
173                (is_null(left) and is_null(right))
174                or (is_null(left) and is_false(right))
175                or (is_false(left) and is_null(right))
176            ):
177                return exp.null()
178            if is_false(left):
179                return right
180            if is_false(right):
181                return left
182            return _simplify_comparison(expression, left, right, or_=True)
183
184    if isinstance(expression, exp.Connector):
185        return _flat_simplify(expression, _simplify_connectors, root)
186    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
INVERSE_COMPARISONS = {<class 'sqlglot.expressions.LT'>: <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GT'>: <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>: <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GTE'>: <class 'sqlglot.expressions.LTE'>}
def remove_compliments(expression, root=True):
269def remove_compliments(expression, root=True):
270    """
271    Removing compliments.
272
273    A AND NOT A -> FALSE
274    A OR NOT A -> TRUE
275    """
276    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
277        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
278
279        for a, b in itertools.permutations(expression.flatten(), 2):
280            if is_complement(a, b):
281                return compliment
282    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate the operands of a connector and order them by their generated SQL.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    # Only the outermost connector of a same-typed chain is processed.
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Keyed by generated SQL, so operands that render identically collapse
    # into a single entry.
    by_sql = {generate(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # Rebuild in sorted order as soon as any key is smaller than its predecessor
    # A AND C AND B -> A AND B AND C
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    # Already in order, but duplicates may still have to be dropped.
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, editing its operands in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, exp.Connector):
        return expression
    # Only the outermost connector of a same-typed chain is processed.
    if not root and expression.same_parent:
        return expression

    # Operands of interest are built with the opposite connector.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    nested_is_and = kind == exp.And

    for nested, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(nested, kind):
            continue

        first, second = nested.unnest_operands()

        # absorb
        if is_complement(other, first):
            # `first` is NOT `other`: swap it for the nested connector's identity value
            first.replace(exp.true() if nested_is_and else exp.false())
        elif is_complement(other, second):
            second.replace(exp.true() if nested_is_and else exp.false())
        elif (set(other.flatten()) if isinstance(other, kind) else {other}) < set(
            nested.flatten()
        ):
            # `other` already implies/covers `nested`: swap `nested` for the
            # outer connector's identity value
            nested.replace(exp.false() if nested_is_and else exp.true())
        elif isinstance(other, kind):
            # eliminate
            other_operands = other.unnest_operands()
            other_first, other_second = other_operands

            if first in other_operands and (
                is_complement(second, other_first) or is_complement(second, other_second)
            ):
                nested.replace(first)
                other.replace(first)
            elif second in other_operands and (
                is_complement(first, other_first) or is_complement(first, other_second)
            ):
                nested.replace(second)
                other.replace(second)

    return expression

absorption:
    A AND (A OR B) -> A
    A OR (A AND B) -> A
    A AND (NOT A OR B) -> A AND B
    A OR (NOT A AND B) -> A OR B

elimination:
    (A AND B) OR (A AND NOT B) -> A
    (A OR B) AND (A OR NOT B) -> A

def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Simplify literal arithmetic.

    Non-connector binary expressions are handed to `_flat_simplify` together with
    `_simplify_binary`; a negation applied to a numeric literal is folded into a
    single signed numeric literal. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # A leading minus cancels against the negation; otherwise one is prepended.
    negated = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(negated)
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Unwrap a parenthesized expression when the parentheses are not needed.

    Returns the wrapped expression when the parentheses can be dropped, and the
    original expression otherwise (including whenever it is not a Paren).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # A parenthesized SELECT always keeps its parentheses.
    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        # (a + b) + c
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        # (a * b) * c, (a * b) + c, (a * b) - c
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    Two rewrites are performed:

    * A COALESCE with a single argument is replaced by that argument
      (unless it is a hint argument).
    * A comparison between a COALESCE and a constant, where the COALESCE has a
      constant among its fallback arguments, is split into two cases:

          (<head> IS NOT NULL AND <head-side comparison>)
          OR (<head> IS NULL AND <constant-vs-constant comparison>)

      where <head> is the COALESCE truncated just before its first constant
      fallback. The result is deliberately more verbose; later simplification
      passes are expected to fold it.

    Args:
        expression: the expression to simplify. The COALESCE inside a matching
            comparison is truncated in place.
    Returns:
        The rewritten expression, or `expression` itself when no rule applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    coalesce_on_left = isinstance(expression.left, exp.Coalesce)
    if coalesce_on_left:
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # The constant fallback must take the COALESCE's place on the SAME side of the
    # comparison; swapping sides would change the meaning of <, <=, > and >=.
    lhs, rhs = (arg, other) if coalesce_on_left else (other, arg)

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=lhs.copy(), expression=rhs.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
SAFE_CONCATS = (<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.SafeDPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    merged = []
    pending = None  # names of the string literals in the run being collected

    def flush():
        # Emit the collected run (if any) as a single string literal.
        if pending is not None:
            merged.append(exp.Literal.string("".join(pending)))

    for operand in expression.expressions or expression.flatten():
        if operand.is_string:
            if pending is None:
                pending = []
            pending.append(operand.name)
        else:
            flush()
            pending = None
            merged.append(operand)
    flush()

    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_type(expressions=merged)

Reduces all groups that contain string literals by concatenating them.

JOINS = {('', 'INNER'), ('RIGHT', ''), ('RIGHT', 'OUTER'), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Strip conditions that are always true, modifying `expression` in place.

    WHERE clauses whose condition is always true are removed from their parent.
    Joins whose ON condition is always true, that have no USING or method, and
    whose (side, kind) pair is listed in JOINS are turned into CROSS joins.
    """
    for where in expression.find_all(exp.Where):
        if not always_true(where.this):
            continue
        where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """
    Report whether an expression is treated as unconditionally true.

    A boolean TRUE qualifies, as does any literal except a numeric literal whose
    value is zero. Zero is falsy wherever numbers are coerced to booleans, so
    `WHERE 0` or `x OR 0` must not be simplified as if the operand were TRUE.

    Args:
        expression: the expression to inspect; may be None (for example a
            missing ON clause).
    Returns:
        A truthy value if the expression is always true, a falsy one otherwise.
    """
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # decimal.InvalidOperation (an ArithmeticError subclass): the numeric
            # text is not something Decimal can parse, so keep the previous
            # behavior of treating the literal as true.
            return True

    return True
def is_complement(a, b):
def is_complement(a, b):
    """Return whether `b` is a NOT expression whose operand equals `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Boolean node (no subclass) holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Null node (subclasses do not count)."""
    exact_type = type(a)
    return exact_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node against two already-extracted Python values.

    Returns a boolean literal expression holding the outcome of applying the
    comparison that `expression` represents to `a` and `b`, or None when
    `expression` is not one of the supported comparison types.
    """
    # Checked in order; the comparison is wrapped in a lambda so that only the
    # matching one is ever evaluated.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def extract_date(cast):
def extract_date(cast):
    """
    Parse the value of a CAST to DATE or DATETIME into a Python date object.

    Returns a `datetime.date` for a DATE cast, a `datetime.datetime` for a
    DATETIME cast, and None when the target type is neither or the text is not
    in ISO format.
    """
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        target_type = cast.args["to"].this
        if target_type == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target_type == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
def extract_interval(interval):
def extract_interval(interval):
    """
    Convert an interval expression into a `dateutil` relativedelta.

    Args:
        interval: an object exposing the quantity as `interval.name` and the
            unit via `interval.text("unit")`.
    Returns:
        A `relativedelta` for year, month, week and day units, or None when the
        interval cannot be converted: `dateutil` is not installed, the quantity
        is not an integer, or the unit is not supported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # The quantity is not an integer literal (e.g. a column reference or
        # '1.5'); report "cannot extract" instead of aborting the caller.
        return None

    unit = interval.text("unit").lower()

    # Supported SQL units mapped to the matching relativedelta keyword.
    keyword = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }.get(unit)
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
def date_literal(date):
def date_literal(date):
    """Build a CAST of `date`, as a string literal, to DATETIME (for datetimes) or DATE."""
    literal = exp.Literal.string(date)
    # datetime.datetime subclasses datetime.date, so it is the type tested for.
    if isinstance(date, datetime.datetime):
        type_name = "DATETIME"
    else:
        type_name = "DATE"
    return exp.cast(literal, type_name)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return a TRUE expression when `condition` is truthy, otherwise a FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()