Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import cached_generator
  9from sqlglot.helper import first, while_changing
 10
# Final means that an expression should not be simplified.
# This is the key stored in an expression's `meta` dict: `simplify` sets it on
# GROUP BY clauses and on projections / HAVING clauses that contain a grouping
# key, and the inner `_simplify` returns any node carrying it unchanged.
FINAL = "final"
 13
 14
 15def simplify(expression):
 16    """
 17    Rewrite sqlglot AST to simplify expressions.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 22        >>> simplify(expression).sql()
 23        'TRUE'
 24
 25    Args:
 26        expression (sqlglot.Expression): expression to simplify
 27    Returns:
 28        sqlglot.Expression: simplified expression
 29    """
 30
 31    generate = cached_generator()
 32
 33    # group by expressions cannot be simplified, for example
 34    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 35    # the projection must exactly match the group by key
 36    for group in expression.find_all(exp.Group):
 37        select = group.parent
 38        groups = set(group.expressions)
 39        group.meta[FINAL] = True
 40
 41        for e in select.selects:
 42            for node, *_ in e.walk():
 43                if node in groups:
 44                    e.meta[FINAL] = True
 45                    break
 46
 47        having = select.args.get("having")
 48        if having:
 49            for node, *_ in having.walk():
 50                if node in groups:
 51                    having.meta[FINAL] = True
 52                    break
 53
 54    def _simplify(expression, root=True):
 55        if expression.meta.get(FINAL):
 56            return expression
 57
 58        # Pre-order transformations
 59        node = expression
 60        node = rewrite_between(node)
 61        node = uniq_sort(node, generate, root)
 62        node = absorb_and_eliminate(node, root)
 63        node = simplify_concat(node)
 64
 65        exp.replace_children(node, lambda e: _simplify(e, False))
 66
 67        # Post-order transformations
 68        node = simplify_not(node)
 69        node = flatten(node)
 70        node = simplify_connectors(node, root)
 71        node = remove_compliments(node, root)
 72        node.parent = expression.parent
 73        node = simplify_literals(node, root)
 74        node = simplify_parens(node)
 75        node = simplify_coalesce(node)
 76
 77        if root:
 78            expression.replace(node)
 79
 80        return node
 81
 82    expression = while_changing(expression, _simplify)
 83    remove_where_true(expression)
 84    return expression
 85
 86
 87def rewrite_between(expression: exp.Expression) -> exp.Expression:
 88    """Rewrite x between y and z to x >= y AND x <= z.
 89
 90    This is done because comparison simplification is only done on lt/lte/gt/gte.
 91    """
 92    if isinstance(expression, exp.Between):
 93        return exp.and_(
 94            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 95            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 96            copy=False,
 97        )
 98    return expression
 99
100
def simplify_not(expression):
    """
    Simplify a NOT expression; anything else is returned unchanged.

    De Morgan's laws (applied to a parenthesized operand):
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also handled:
        NOT NULL -> NULL (with or without parentheses)
        NOT <always true> -> FALSE
        NOT FALSE -> TRUE
        NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # The negation of an AND is an OR of negations, and vice versa.
            dual = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return dual(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
135
136
def flatten(expression):
    """
    Drop parentheses around an operand that is the same kind of connector.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    The connector is modified in place and returned; non-connectors are
    returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
148
149
def simplify_connectors(expression, root=True):
    """
    Fold constant operands of an AND / OR chain and merge redundant comparisons.

    Each pair of operands is handed to `_simplify_connectors` through
    `_flat_simplify`; a pair is replaced whenever that helper returns a truthy
    expression. Non-connector expressions are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # x AND x -> x, x OR x -> x
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true = always_true(left)
            right_true = always_true(right)

            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false = is_false(left)
            right_false = is_false(right)

            if left_false and right_false:
                return exp.false()

            left_null = is_null(left)
            right_null = is_null(right)

            if (
                (left_null and right_null)
                or (left_null and right_false)
                or (left_false and right_null)
            ):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        # Any other connector type: nothing to simplify.
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
186
187
# Comparison classes grouped by direction; `_simplify_comparison` uses them to
# tell "less than" style checks (<, <=) from "greater than" style checks (>, >=).
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Binary predicates accepted as operands by `_simplify_comparison` and
# `simplify_coalesce`.
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps an inequality to the class to use once its operands are swapped,
# e.g. `1 < x` is rewritten as `x > 1`.
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
205
206
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that test the same column against literals.

    Args:
        expression: the AND / OR connector that `left` and `right` belong to.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the connector is an OR, False for an AND.

    Returns:
        The expression that replaces the pair, or None when nothing can be
        merged. Note that `expression` itself is returned when one of the
        comparisons has no operand other than the shared column(s).
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # operands that appear in both comparisons, and the columns among them
        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            # l / r are the operands each comparison does not share via a column;
            # `first` raises StopIteration when there is no such operand
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            # only number/number and string/string pairs can be compared here;
            # from this point on l and r hold plain Python values
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # try both orderings: (left, right) and then (right, left)
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # two "less than" checks: AND keeps the smaller bound, OR the larger one
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                # two "greater than" checks: AND keeps the larger bound, OR the smaller one
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # x < a AND x >(=) b can never hold when a <= b
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    # x > a AND x <(=) b can never hold when a >= b
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    # x = a AND <other check>: either contradictory (FALSE) or
                    # the equality alone is enough
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
266
267
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only the root expression, or a connector for which `same_parent` is false,
    is inspected; anything else is returned unchanged.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    outcome = exp.false() if isinstance(expression, exp.And) else exp.true()
    operand_pairs = itertools.permutations(expression.flatten(), 2)

    if any(is_complement(a, b) for a, b in operand_pairs):
        return outcome
    return expression
282
283
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate the operands of a connector and put them in sorted order.

    C AND A AND B AND B -> A AND B AND C

    Operands are compared and ordered by the SQL text that `generate` returns
    for them. Only the root expression, or a connector for which `same_parent`
    is false, is processed; anything else is returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Keyed by generated SQL, so operands with identical SQL collapse into one entry.
    by_sql = {generate(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # A AND C AND B -> A AND B AND C
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        return combine(*(node for _, node in sorted(by_sql.items())), copy=False)

    # Already in order, but duplicates may still have to go.
    if len(by_sql) < len(operands):
        return combine(*by_sql.values(), copy=False)

    return expression
308
309
def absorb_and_eliminate(expression, root=True):
    """
    Apply the boolean absorption and elimination rules to a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place (parts are swapped for TRUE / FALSE or
    for a shared operand) and the same `expression` object is returned. Only
    the root expression, or a connector for which `same_parent` is false, is
    processed.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type an operand must have for
        # these rules to apply to it
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # an operand of `a` that is the negation of `b` becomes the
                # neutral constant of `a` (TRUE inside an AND, FALSE inside an OR)
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # the operands of `b` are a proper subset of those of `a`:
                # `a` as a whole is replaced by a constant
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # `a` and `b` share one operand and the other operand of `a`
                    # is negated in `b`: both collapse to the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
348
349
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    Non-connector binary expressions are folded pairwise with
    `_simplify_binary` (through `_flat_simplify`). A negated number literal is
    turned into a single literal with the sign flipped. Anything else is
    returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Strip an existing minus sign, otherwise add one.
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    return expression
362
363
def _simplify_binary(expression, a, b):
    """
    Evaluate a binary expression whose operands `a` and `b` are constants.

    Args:
        expression: the binary expression the operands belong to.
        a: left operand.
        b: right operand.

    Returns:
        The expression that replaces the pair, or None when the pair cannot be
        folded.
    """
    if isinstance(expression, exp.Is):
        # x IS [NOT] NULL
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            # a literal is never NULL, NULL always is
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        # any other binary operation with a NULL operand yields NULL
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(num_a, int) and isinstance(num_b, int):
                return None
            # Dividing a Decimal by zero raises decimal.DivisionByZero; leave
            # the expression as written instead of crashing.
            if num_b == 0:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        date, delta = extract_date(a), extract_interval(b)
        if date and delta:
            if isinstance(expression, exp.Add):
                return date_literal(date + delta)
            if isinstance(expression, exp.Sub):
                return date_literal(date - delta)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        delta, date = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if delta and date and isinstance(expression, exp.Add):
            return date_literal(delta + date)

    return None
422
423
def simplify_parens(expression):
    """
    Remove parentheses that are not needed.

    Returns the wrapped expression when the parentheses can be dropped,
    otherwise the input unchanged. Parentheses around a SELECT are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        # (a + b) + c
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        # (a * b) * c, (a * b) + c, (a * b) - c
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if removable else expression
441
442
# Expression classes that `simplify_coalesce` treats as constants.
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)
448
449
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    COALESCE(x) -> x, and a comparison between a COALESCE and a constant is
    rewritten as

        (<coalesce> IS NOT NULL AND <comparison>)
        OR (<coalesce> IS NULL AND <first constant arg> <op> <constant>)

    where the COALESCE has been cut off before its first constant argument.
    Anything else is returned unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right

    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    constant_position = next(
        (
            position
            for position, candidate in enumerate(coalesce.expressions)
            if isinstance(candidate, CONSTANTS)
        ),
        None,
    )
    if constant_position is None:
        return expression

    first_constant = coalesce.expressions[constant_position]
    coalesce.set("expressions", coalesce.expressions[:constant_position])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    when_not_null = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=first_constant.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.or_(when_not_null, when_null, copy=False)
504
505
# Concatenation expressions handled by `simplify_concat`.
CONCATS = (exp.Concat, exp.DPipe)
# Concatenations that `simplify_concat` rebuilds as `exp.SafeConcat`
# (everything else is rebuilt as `exp.Concat`).
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
508
509
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Consecutive string literals are joined into one literal; other arguments
    are kept in place. When a single argument remains it is returned on its
    own. CONCAT_WS and non-concat expressions are returned unchanged.
    """
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()
    merged = []

    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if all_strings:
            joined = "".join(literal.name for literal in run)
            merged.append(exp.Literal.string(joined))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_type(expressions=merged)
527
528
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
# Each entry is a (side, kind) pair, compared against (join.side, join.kind)
# in `remove_where_true`.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
538
539
def remove_where_true(expression):
    """
    Drop conditions that are always true, modifying `expression` in place.

    A WHERE clause whose condition is always true is removed. A join whose ON
    condition is always true, that has no USING clause and no method, and whose
    (side, kind) pair is listed in JOINS, is turned into a CROSS join.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args

        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
554
555
def always_true(expression):
    """
    Return a truthy value when `expression` is treated as always true.

    That is the case for a boolean whose value is truthy and for any
    `exp.Literal` (the literal's value is not inspected).
    """
    if isinstance(expression, exp.Boolean):
        value = expression.this
        if value:
            return value
    return isinstance(expression, exp.Literal)
560
561
def is_complement(a, b):
    """Return True when `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
564
565
def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly an `exp.Boolean` (not a subclass) with a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
568
569
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly an `exp.Null` (subclasses do not count)."""
    exact_type = type(a)
    return exact_type is exp.Null
572
573
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a TRUE / FALSE boolean expression, or None when `expression` is
    not one of the supported comparisons (=, IS, <>, >, >=, <, <=).
    """
    # Checked in order; each comparison is only evaluated once its type matched.
    evaluators = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for comparison_types, evaluate in evaluators:
        if isinstance(expression, comparison_types):
            return boolean_literal(evaluate())
    return None
588
589
def extract_date(cast):
    """
    Parse the value of a CAST to DATE / DATETIME into a Python date object.

    Returns a `datetime.date` for DATE, a `datetime.datetime` for DATETIME,
    and None for any other target type or when the value is not in ISO format.
    """
    target_type = cast.args["to"].this

    if target_type == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target_type == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        return parse(cast.name)
    except ValueError:
        return None
600
601
def extract_interval(interval):
    """
    Convert an INTERVAL expression into a `dateutil` relativedelta.

    Args:
        interval: expression whose `name` is the interval amount and whose
            "unit" text is the interval unit.

    Returns:
        A relativedelta for whole-number amounts of year, month, week or day.
        None when `dateutil` is not installed, when the amount is not an
        integer (e.g. `INTERVAL '1 day'` or `INTERVAL '1.5' DAY`), or when
        the unit is not one of those four.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    # The amount is not guaranteed to be an integer literal; give up on
    # anything else instead of letting ValueError escape.
    try:
        n = int(interval.name)
    except ValueError:
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None
620
621
def date_literal(date):
    """
    Build a CAST of the string form of `date`.

    The target type is DATETIME for `datetime.datetime` values and DATE otherwise.
    """
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
627
628
def boolean_literal(condition):
    """Return a TRUE expression when `condition` is truthy, otherwise a FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
631
632
def _flat_simplify(expression, simplifier, root=True):
    """
    Simplify a flattened chain of operands pairwise.

    Args:
        expression: the expression whose flattened operands are simplified.
        simplifier: callable taking (expression, a, b) and returning a
            replacement for the pair, or a falsy value when the pair cannot
            be combined.
        root: whether `expression` is the root of the simplification.

    Returns:
        A chain of the same class rebuilt from the remaining operands when at
        least one pair was combined, otherwise `expression` unchanged. Only the
        root expression, or one for which `same_parent` is false, is processed.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    settled = []

    while pending:
        head = pending.popleft()

        for other in pending:
            merged = simplifier(expression, head, other)

            if merged:
                # Replace the pair by its simplification and try again from there.
                pending.remove(other)
                pending.appendleft(merged)
                break
        else:
            # Nothing left combines with `head`; it is final.
            settled.append(head)

    if len(settled) >= original_count:
        return expression

    connector_type = expression.__class__
    return functools.reduce(
        lambda lhs, rhs: connector_type(this=lhs, expression=rhs), settled
    )
FINAL = 'final'
def simplify(expression):
16def simplify(expression):
17    """
18    Rewrite sqlglot AST to simplify expressions.
19
20    Example:
21        >>> import sqlglot
22        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
23        >>> simplify(expression).sql()
24        'TRUE'
25
26    Args:
27        expression (sqlglot.Expression): expression to simplify
28    Returns:
29        sqlglot.Expression: simplified expression
30    """
31
32    generate = cached_generator()
33
34    # group by expressions cannot be simplified, for example
35    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
36    # the projection must exactly match the group by key
37    for group in expression.find_all(exp.Group):
38        select = group.parent
39        groups = set(group.expressions)
40        group.meta[FINAL] = True
41
42        for e in select.selects:
43            for node, *_ in e.walk():
44                if node in groups:
45                    e.meta[FINAL] = True
46                    break
47
48        having = select.args.get("having")
49        if having:
50            for node, *_ in having.walk():
51                if node in groups:
52                    having.meta[FINAL] = True
53                    break
54
55    def _simplify(expression, root=True):
56        if expression.meta.get(FINAL):
57            return expression
58
59        # Pre-order transformations
60        node = expression
61        node = rewrite_between(node)
62        node = uniq_sort(node, generate, root)
63        node = absorb_and_eliminate(node, root)
64        node = simplify_concat(node)
65
66        exp.replace_children(node, lambda e: _simplify(e, False))
67
68        # Post-order transformations
69        node = simplify_not(node)
70        node = flatten(node)
71        node = simplify_connectors(node, root)
72        node = remove_compliments(node, root)
73        node.parent = expression.parent
74        node = simplify_literals(node, root)
75        node = simplify_parens(node)
76        node = simplify_coalesce(node)
77
78        if root:
79            expression.replace(node)
80
81        return node
82
83    expression = while_changing(expression, _simplify)
84    remove_where_true(expression)
85    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
88def rewrite_between(expression: exp.Expression) -> exp.Expression:
89    """Rewrite x between y and z to x >= y AND x <= z.
90
91    This is done because comparison simplification is only done on lt/lte/gt/gte.
92    """
93    if isinstance(expression, exp.Between):
94        return exp.and_(
95            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
96            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
97            copy=False,
98        )
99    return expression

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
102def simplify_not(expression):
103    """
104    Demorgan's Law
105    NOT (x OR y) -> NOT x AND NOT y
106    NOT (x AND y) -> NOT x OR NOT y
107    """
108    if isinstance(expression, exp.Not):
109        if is_null(expression.this):
110            return exp.null()
111        if isinstance(expression.this, exp.Paren):
112            condition = expression.this.unnest()
113            if isinstance(condition, exp.And):
114                return exp.or_(
115                    exp.not_(condition.left, copy=False),
116                    exp.not_(condition.right, copy=False),
117                    copy=False,
118                )
119            if isinstance(condition, exp.Or):
120                return exp.and_(
121                    exp.not_(condition.left, copy=False),
122                    exp.not_(condition.right, copy=False),
123                    copy=False,
124                )
125            if is_null(condition):
126                return exp.null()
127        if always_true(expression.this):
128            return exp.false()
129        if is_false(expression.this):
130            return exp.true()
131        if isinstance(expression.this, exp.Not):
132            # double negation
133            # NOT NOT x -> x
134            return expression.this.this
135    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
138def flatten(expression):
139    """
140    A AND (B AND C) -> A AND B AND C
141    A OR (B OR C) -> A OR B OR C
142    """
143    if isinstance(expression, exp.Connector):
144        for node in expression.args.values():
145            child = node.unnest()
146            if isinstance(child, expression.__class__):
147                node.replace(child)
148    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
151def simplify_connectors(expression, root=True):
152    def _simplify_connectors(expression, left, right):
153        if left == right:
154            return left
155        if isinstance(expression, exp.And):
156            if is_false(left) or is_false(right):
157                return exp.false()
158            if is_null(left) or is_null(right):
159                return exp.null()
160            if always_true(left) and always_true(right):
161                return exp.true()
162            if always_true(left):
163                return right
164            if always_true(right):
165                return left
166            return _simplify_comparison(expression, left, right)
167        elif isinstance(expression, exp.Or):
168            if always_true(left) or always_true(right):
169                return exp.true()
170            if is_false(left) and is_false(right):
171                return exp.false()
172            if (
173                (is_null(left) and is_null(right))
174                or (is_null(left) and is_false(right))
175                or (is_false(left) and is_null(right))
176            ):
177                return exp.null()
178            if is_false(left):
179                return right
180            if is_false(right):
181                return left
182            return _simplify_comparison(expression, left, right, or_=True)
183
184    if isinstance(expression, exp.Connector):
185        return _flat_simplify(expression, _simplify_connectors, root)
186    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
INVERSE_COMPARISONS = {<class 'sqlglot.expressions.LT'>: <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GT'>: <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>: <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GTE'>: <class 'sqlglot.expressions.LTE'>}
def remove_compliments(expression, root=True):
269def remove_compliments(expression, root=True):
270    """
271    Removing compliments.
272
273    A AND NOT A -> FALSE
274    A OR NOT A -> TRUE
275    """
276    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
277        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
278
279        for a, b in itertools.permutations(expression.flatten(), 2):
280            if is_complement(a, b):
281                return compliment
282    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, generate, root=True):
285def uniq_sort(expression, generate, root=True):
286    """
287    Uniq and sort a connector.
288
289    C AND A AND B AND B -> A AND B AND C
290    """
291    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
292        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
293        flattened = tuple(expression.flatten())
294        deduped = {generate(e): e for e in flattened}
295        arr = tuple(deduped.items())
296
297        # check if the operands are already sorted, if not sort them
298        # A AND C AND B -> A AND B AND C
299        for i, (sql, e) in enumerate(arr[1:]):
300            if sql < arr[i][0]:
301                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
302                break
303        else:
304            # we didn't have to sort but maybe we need to dedup
305            if len(deduped) < len(flattened):
306                expression = result_func(*deduped.values(), copy=False)
307
308    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the boolean absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, exp.Connector):
        return expression
    # Only the outermost connector of a same-typed chain is processed
    if not root and expression.same_parent:
        return expression

    # The nested connectors of interest are of the opposite type to the outer one
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    nested_is_and = kind == exp.And

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        lhs, rhs = a.unnest_operands()

        # absorb
        if is_complement(b, lhs):
            lhs.replace(exp.true() if nested_is_and else exp.false())
            continue
        if is_complement(b, rhs):
            rhs.replace(exp.true() if nested_is_and else exp.false())
            continue

        b_operands = set(b.flatten()) if isinstance(b, kind) else {b}
        if b_operands < set(a.flatten()):
            a.replace(exp.false() if nested_is_and else exp.true())
            continue

        if not isinstance(b, kind):
            continue

        # eliminate
        others = b.unnest_operands()
        other_lhs, other_rhs = others

        if lhs in others and (is_complement(rhs, other_lhs) or is_complement(rhs, other_rhs)):
            a.replace(lhs)
            b.replace(lhs)
        elif rhs in others and (is_complement(lhs, other_lhs) or is_complement(lhs, other_rhs)):
            a.replace(rhs)
            b.replace(rhs)

    return expression

absorption:
    A AND (A OR B) -> A
    A OR (A AND B) -> A
    A AND (NOT A OR B) -> A AND B
    A OR (NOT A AND B) -> A OR B
elimination:
    (A AND B) OR (A AND NOT B) -> A
    (A OR B) AND (A OR NOT B) -> A

def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Simplify literal arithmetic.

    Binary expressions other than connectors are delegated to `_flat_simplify` with
    `_simplify_binary`; a negated number literal is folded into a single signed literal.
    Anything else is returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    # Toggle the sign in the literal's text: "-5" -> "5", "5" -> "-5"
    digits = operand.name
    negated = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(negated)
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Drop a pair of parentheses when one of the checks below applies.

    Returns the wrapped expression if the parentheses are removed, otherwise
    `expression` itself.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # A parenthesized SELECT is always kept as is
    if isinstance(inner, exp.Select):
        return expression

    if not isinstance(outer, (exp.Condition, exp.Binary)):
        return inner
    if isinstance(inner, exp.Predicate) or not isinstance(inner, exp.Binary):
        return inner
    # (a + b) + c
    if isinstance(inner, exp.Add) and isinstance(outer, exp.Add):
        return inner
    # (a * b) * c, (a * b) + c, (a * b) - c
    if isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)):
        return inner

    return expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    A COALESCE with a single argument is replaced by that argument. A comparison
    between a COALESCE and a constant, where the COALESCE has a constant fallback
    argument, is split into two cases joined by OR:

        (<args before the constant> IS NOT NULL AND <original comparison>)
        OR (<args before the constant> IS NULL AND <comparison using the constant>)

    The COALESCE node is truncated in place to the arguments preceding its first
    constant argument.

    Args:
        expression: the expression to simplify
    Returns:
        The simplified expression, or `expression` itself if nothing applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    coalesce_on_left = isinstance(expression.left, exp.Coalesce)
    if coalesce_on_left:
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # The constant takes the COALESCE's place in the comparison, so it has to stay on the
    # same side: swapping the operands would change the meaning of <, <=, > and >=.
    if coalesce_on_left:
        null_case = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        null_case = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.or_(
        exp.and_(
            coalesce.is_(exp.null()).not_(copy=False),
            expression.copy(),
            copy=False,
        ),
        exp.and_(
            coalesce.is_(exp.null()),
            null_case,
            copy=False,
        ),
        copy=False,
    )
# Concatenation node types handled by simplify_concat
CONCATS = (exp.Concat, exp.DPipe)
# "Safe" concatenation node types; simplify_concat rebuilds these as exp.SafeConcat
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
def simplify_concat(expression):
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal."""
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if not all_strings:
            merged.extend(run)
            continue
        text = "".join(literal.name for literal in run)
        merged.append(exp.Literal.string(text))

    # Everything collapsed into a single operand, so no concat node is needed
    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    if isinstance(expression, SAFE_CONCATS):
        return exp.SafeConcat(expressions=merged)
    return exp.Concat(expressions=merged)

Reduces all groups that contain string literals by concatenating them.

# (side, kind) pairs of the joins that remove_where_true turns into CROSS joins
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Strip conditions that are always true, modifying `expression` in place.

    A WHERE clause whose condition is always true is removed. A join whose ON
    condition is always true, that has no USING or method, and whose
    (side, kind) is listed in JOINS becomes a CROSS join.
    """
    for where in expression.find_all(exp.Where):
        if not always_true(where.this):
            continue
        where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args
        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """
    Check whether `expression` is a constant that can be treated as always true.

    A TRUE boolean qualifies, and so does any literal except a numeric literal
    equal to zero. Zero is excluded because a condition such as `WHERE 0` must not
    be treated as a no-op filter.

    Args:
        expression: the condition to inspect; may be None.
    Returns:
        bool: True if the condition is known to always hold.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_number:
        try:
            return float(expression.name) != 0
        except ValueError:
            # Unparseable number text: treat it like any other literal
            return True

    return True
def is_complement(a, b):
def is_complement(a, b):
    """Return whether `b` is a NOT node wrapping an expression equal to `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly an `exp.Boolean` node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly an `exp.Null` node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` against the Python values `a` and `b`.

    Returns a boolean literal node holding the result, or None if `expression`
    is not one of the supported comparison types.
    """
    # Checked in order; the comparison itself only runs for the matching entry
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def extract_date(cast):
def extract_date(cast):
    """
    Parse the value of a cast to DATE or DATETIME into a Python date/datetime.

    Returns None if the cast targets any other type or the value is not in ISO format.
    """
    try:
        target = cast.args["to"].this
        if target == exp.DataType.Type.DATE:
            parse = datetime.date.fromisoformat
        elif target == exp.DataType.Type.DATETIME:
            parse = datetime.datetime.fromisoformat
        else:
            return None
        return parse(cast.name)
    except ValueError:
        # The "fromisoformat" conversion could fail if the cast is used on an identifier,
        # so in that case we can't extract the date.
        return None
def extract_interval(interval):
def extract_interval(interval):
    """
    Convert an interval node into a `dateutil.relativedelta.relativedelta`.

    Args:
        interval: a node exposing `name` (the interval value) and `text("unit")`.
    Returns:
        The matching relativedelta, or None when dateutil is not installed, the
        value is not an integer, or the unit is not one of year/month/week/day.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    # The value is not guaranteed to be integer text (e.g. '1.5' or a non-numeric name);
    # report "can't extract" the same way extract_date does instead of raising.
    try:
        n = int(interval.name)
    except ValueError:
        return None

    unit = interval.text("unit").lower()
    keyword = {"year": "years", "month": "months", "week": "weeks", "day": "days"}.get(unit)
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
def date_literal(date):
def date_literal(date):
    """Build a cast of the string form of `date` to DATETIME (for datetimes) or DATE."""
    if isinstance(date, datetime.datetime):
        target_type = "DATETIME"
    else:
        target_type = "DATE"

    text = exp.Literal.string(date)
    return exp.cast(text, target_type)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return a TRUE node if `condition` is truthy, otherwise a FALSE node."""
    if condition:
        return exp.true()
    return exp.false()