Edit on GitHub

sqlglot.optimizer.simplify

   1import datetime
   2import functools
   3import itertools
   4import typing as t
   5from collections import deque
   6from decimal import Decimal
   7
   8import sqlglot
   9from sqlglot import exp
  10from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
  11from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  12
# Final means that an expression should not be simplified.
# `simplify` stores this key in `Expression.meta` for nodes that have to be kept
# verbatim (GROUP BY keys and the projections / HAVING clauses that use them) and
# returns any node carrying it unchanged.
FINAL = "final"
  15
  16
  17class UnsupportedUnit(Exception):
  18    pass
  19
  20
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rules are applied repeatedly (through `while_changing`) and the tree is
    modified in place; afterwards `remove_where_true` is run once.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains one of the GROUP BY keys
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # The HAVING clause is frozen as a whole under the same condition
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL above are returned as-is, children included
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are simplified with root=False
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # `node` may be a freshly built expression at this point; give it the original
        # parent so that the parent-dependent rules below (e.g. simplify_parens) can see it
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the root call splices the result into the tree itself; for children
        # this is done by exp.replace_children above
        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
  97
  98
  99def catch(*exceptions):
 100    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 101
 102    def decorator(func):
 103        def wrapped(expression, *args, **kwargs):
 104            try:
 105                return func(expression, *args, **kwargs)
 106            except exceptions:
 107                return expression
 108
 109        return wrapped
 110
 111    return decorator
 112
 113
 114def rewrite_between(expression: exp.Expression) -> exp.Expression:
 115    """Rewrite x between y and z to x >= y AND x <= z.
 116
 117    This is done because comparison simplification is only done on lt/lte/gt/gte.
 118    """
 119    if isinstance(expression, exp.Between):
 120        return exp.and_(
 121            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 122            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 123            copy=False,
 124        )
 125    return expression
 126
 127
 128def simplify_not(expression):
 129    """
 130    Demorgan's Law
 131    NOT (x OR y) -> NOT x AND NOT y
 132    NOT (x AND y) -> NOT x OR NOT y
 133    """
 134    if isinstance(expression, exp.Not):
 135        if is_null(expression.this):
 136            return exp.null()
 137        if isinstance(expression.this, exp.Paren):
 138            condition = expression.this.unnest()
 139            if isinstance(condition, exp.And):
 140                return exp.or_(
 141                    exp.not_(condition.left, copy=False),
 142                    exp.not_(condition.right, copy=False),
 143                    copy=False,
 144                )
 145            if isinstance(condition, exp.Or):
 146                return exp.and_(
 147                    exp.not_(condition.left, copy=False),
 148                    exp.not_(condition.right, copy=False),
 149                    copy=False,
 150                )
 151            if is_null(condition):
 152                return exp.null()
 153        if always_true(expression.this):
 154            return exp.false()
 155        if is_false(expression.this):
 156            return exp.true()
 157        if isinstance(expression.this, exp.Not):
 158            # double negation
 159            # NOT NOT x -> x
 160            return expression.this.this
 161    return expression
 162
 163
 164def flatten(expression):
 165    """
 166    A AND (B AND C) -> A AND B AND C
 167    A OR (B OR C) -> A OR B OR C
 168    """
 169    if isinstance(expression, exp.Connector):
 170        for node in expression.args.values():
 171            child = node.unnest()
 172            if isinstance(child, expression.__class__):
 173                node.replace(child)
 174    return expression
 175
 176
 177def simplify_connectors(expression, root=True):
 178    def _simplify_connectors(expression, left, right):
 179        if left == right:
 180            return left
 181        if isinstance(expression, exp.And):
 182            if is_false(left) or is_false(right):
 183                return exp.false()
 184            if is_null(left) or is_null(right):
 185                return exp.null()
 186            if always_true(left) and always_true(right):
 187                return exp.true()
 188            if always_true(left):
 189                return right
 190            if always_true(right):
 191                return left
 192            return _simplify_comparison(expression, left, right)
 193        elif isinstance(expression, exp.Or):
 194            if always_true(left) or always_true(right):
 195                return exp.true()
 196            if is_false(left) and is_false(right):
 197                return exp.false()
 198            if (
 199                (is_null(left) and is_null(right))
 200                or (is_null(left) and is_false(right))
 201                or (is_false(left) and is_null(right))
 202            ):
 203                return exp.null()
 204            if is_false(left):
 205                return right
 206            if is_false(right):
 207                return left
 208            return _simplify_comparison(expression, left, right, or_=True)
 209
 210    if isinstance(expression, exp.Connector):
 211        return _flat_simplify(expression, _simplify_connectors, root)
 212    return expression
 213
 214
# Comparison classes grouped by direction, for use with isinstance()
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Binary predicates handled by the comparison-based rules of this module
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Class to use when the two operands of an inequality are swapped, e.g. `1 < x` -> `x > 1`
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
 232
 233
def _simplify_comparison(expression, left, right, or_=False):
    """Try to merge two comparisons of the same column against literals.

    Args:
        expression: the connector that joins `left` and `right`.
        left: left operand of the connector.
        right: right operand of the connector.
        or_: True when the connector is an OR, False for an AND.

    Returns:
        The expression that replaces `left <connector> right`, or None when no rule
        applies. `expression` itself is returned when the two comparisons share a
        column but no other operand can be extracted from one of them.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands that appear in both comparisons; only shared columns are of interest
        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            # l / r are the operands the shared column is compared against
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            # Only number/number and string/string pairs can be compared statically
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # Try both (left, right) and (right, left) so each rule is written only once
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two bounds in the same direction: keep one of the two comparisons
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Contradicting bounds under AND collapse to FALSE
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    # An equality combined with another predicate: either the equality
                    # alone is enough, or the two contradict each other
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 293
 294
 295def remove_complements(expression, root=True):
 296    """
 297    Removing complements.
 298
 299    A AND NOT A -> FALSE
 300    A OR NOT A -> TRUE
 301    """
 302    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 303        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 304
 305        for a, b in itertools.permutations(expression.flatten(), 2):
 306            if is_complement(a, b):
 307                return complement
 308    return expression
 309
 310
 311def uniq_sort(expression, root=True):
 312    """
 313    Uniq and sort a connector.
 314
 315    C AND A AND B AND B -> A AND B AND C
 316    """
 317    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 318        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 319        flattened = tuple(expression.flatten())
 320        deduped = {gen(e): e for e in flattened}
 321        arr = tuple(deduped.items())
 322
 323        # check if the operands are already sorted, if not sort them
 324        # A AND C AND B -> A AND B AND C
 325        for i, (sql, e) in enumerate(arr[1:]):
 326            if sql < arr[i][0]:
 327                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 328                break
 329        else:
 330            # we didn't have to sort but maybe we need to dedup
 331            if len(deduped) < len(flattened):
 332                expression = result_func(*deduped.values(), copy=False)
 333
 334    return expression
 335
 336
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are performed in place by replacing operands with constants or
    with one of their own children; the (mutated) connector is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: only operands of that kind can be absorbed
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # `aa` is NOT b: replace it with the identity element of `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # the operands of `b` are a proper subset of those of `a`: `a` is redundant
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # one operand is shared and the other two are complements of each other
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 375
 376
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    The columns are replaced in place and `expression` is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps every column that is compared to a literal to
        # (id of that column node, the literal it equals)
        constant_mapping = {}
        # exp.If nodes are passed as the prune condition of the walk
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    # normalize to (column, literal)
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the occurrence that defined the constant (same node id) and
                # leave `column IS NULL` checks untouched
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
 420
 421
# Date arithmetic functions mapped to the binary operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations `simplify_equality` can move to the other side of a comparison
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 434
 435
def _is_number(expression: exp.Expression) -> bool:
    """Function form of `expression.is_number`, usable as a predicate in `simplify_equality`."""
    return expression.is_number
 438
 439
 440def _is_interval(expression: exp.Expression) -> bool:
 441    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 442
 443
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    Returns a new comparison of the same class, or `expression` itself when the
    pattern does not match (or when one of the caught exceptions is raised).
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        # Make `l` the side that holds the invertible operation
        # NOTE(review): the operands are swapped here without inverting the comparison
        # class (see INVERSE_COMPARISONS) - confirm this is intended for LT/GT/LTE/GTE
        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            l, r = r, l
        else:
            return expression

        # The constant side decides which kinds of operands are acceptable for a and b
        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        # Exactly one operand has to be a constant; make `b` that constant
        # NOTE(review): a and b are also swapped for exp.Sub, where the operand order
        # matters (`1 - x`) - verify against callers / tests
        if not a_predicate(a) and b_predicate(b):
            pass
        elif not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        # a <op> b <cmp> r  ->  a <cmp> r <inverse op> b
        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression
 496
 497
 498def simplify_literals(expression, root=True):
 499    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 500        return _flat_simplify(expression, _simplify_binary, root)
 501
 502    if isinstance(expression, exp.Neg):
 503        this = expression.this
 504        if this.is_number:
 505            value = this.name
 506            if value[0] == "-":
 507                return exp.Literal.number(value[1:])
 508            return exp.Literal.number(f"-{value}")
 509
 510    return expression
 511
 512
 513def _simplify_binary(expression, a, b):
 514    if isinstance(expression, exp.Is):
 515        if isinstance(b, exp.Not):
 516            c = b.this
 517            not_ = True
 518        else:
 519            c = b
 520            not_ = False
 521
 522        if is_null(c):
 523            if isinstance(a, exp.Literal):
 524                return exp.true() if not_ else exp.false()
 525            if is_null(a):
 526                return exp.false() if not_ else exp.true()
 527    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
 528        return None
 529    elif is_null(a) or is_null(b):
 530        return exp.null()
 531
 532    if a.is_number and b.is_number:
 533        a = int(a.name) if a.is_int else Decimal(a.name)
 534        b = int(b.name) if b.is_int else Decimal(b.name)
 535
 536        if isinstance(expression, exp.Add):
 537            return exp.Literal.number(a + b)
 538        if isinstance(expression, exp.Sub):
 539            return exp.Literal.number(a - b)
 540        if isinstance(expression, exp.Mul):
 541            return exp.Literal.number(a * b)
 542        if isinstance(expression, exp.Div):
 543            # engines have differing int div behavior so intdiv is not safe
 544            if isinstance(a, int) and isinstance(b, int):
 545                return None
 546            return exp.Literal.number(a / b)
 547
 548        boolean = eval_boolean(expression, a, b)
 549
 550        if boolean:
 551            return boolean
 552    elif a.is_string and b.is_string:
 553        boolean = eval_boolean(expression, a.this, b.this)
 554
 555        if boolean:
 556            return boolean
 557    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 558        a, b = extract_date(a), extract_interval(b)
 559        if a and b:
 560            if isinstance(expression, exp.Add):
 561                return date_literal(a + b)
 562            if isinstance(expression, exp.Sub):
 563                return date_literal(a - b)
 564    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 565        a, b = extract_interval(a), extract_date(b)
 566        # you cannot subtract a date from an interval
 567        if a and b and isinstance(expression, exp.Add):
 568            return date_literal(a + b)
 569    elif _is_date_literal(a) and _is_date_literal(b):
 570        if isinstance(expression, exp.Predicate):
 571            a, b = extract_date(a), extract_date(b)
 572            boolean = eval_boolean(expression, a, b)
 573            if boolean:
 574                return boolean
 575
 576    return None
 577
 578
 579def simplify_parens(expression):
 580    if not isinstance(expression, exp.Paren):
 581        return expression
 582
 583    this = expression.this
 584    parent = expression.parent
 585
 586    if not isinstance(this, exp.Select) and (
 587        not isinstance(parent, (exp.Condition, exp.Binary))
 588        or isinstance(parent, exp.Paren)
 589        or not isinstance(this, exp.Binary)
 590        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
 591        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 592        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 593        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 594    ):
 595        return this
 596    return expression
 597
 598
# Expression classes treated as constants that are not NULL
NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

# Same as NONNULL_CONSTANTS, plus the NULL literal
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)
 609
 610
def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Whether `expression` is one of NONNULL_CONSTANTS or a date literal."""
    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
 613
 614
def _is_constant(expression: exp.Expression) -> bool:
    """Whether `expression` is one of CONSTANTS (NULL included) or a date literal."""
    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
 617
 618
 619def simplify_coalesce(expression):
 620    # COALESCE(x) -> x
 621    if (
 622        isinstance(expression, exp.Coalesce)
 623        and (not expression.expressions or _is_nonnull_constant(expression.this))
 624        # COALESCE is also used as a Spark partitioning hint
 625        and not isinstance(expression.parent, exp.Hint)
 626    ):
 627        return expression.this
 628
 629    if not isinstance(expression, COMPARISONS):
 630        return expression
 631
 632    if isinstance(expression.left, exp.Coalesce):
 633        coalesce = expression.left
 634        other = expression.right
 635    elif isinstance(expression.right, exp.Coalesce):
 636        coalesce = expression.right
 637        other = expression.left
 638    else:
 639        return expression
 640
 641    # This transformation is valid for non-constants,
 642    # but it really only does anything if they are both constants.
 643    if not _is_constant(other):
 644        return expression
 645
 646    # Find the first constant arg
 647    for arg_index, arg in enumerate(coalesce.expressions):
 648        if _is_constant(other):
 649            break
 650    else:
 651        return expression
 652
 653    coalesce.set("expressions", coalesce.expressions[:arg_index])
 654
 655    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 656    # since we already remove COALESCE at the top of this function.
 657    coalesce = coalesce if coalesce.expressions else coalesce.this
 658
 659    # This expression is more complex than when we started, but it will get simplified further
 660    return exp.paren(
 661        exp.or_(
 662            exp.and_(
 663                coalesce.is_(exp.null()).not_(copy=False),
 664                expression.copy(),
 665                copy=False,
 666            ),
 667            exp.and_(
 668                coalesce.is_(exp.null()),
 669                type(expression)(this=arg.copy(), expression=other.copy()),
 670                copy=False,
 671            ),
 672            copy=False,
 673        )
 674    )
 675
 676
# Concatenation expressions handled by `simplify_concat`
CONCATS = (exp.Concat, exp.DPipe)
# Variants for which `simplify_concat` builds an exp.SafeConcat instead of an exp.Concat
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
 679
 680
 681def simplify_concat(expression):
 682    """Reduces all groups that contain string literals by concatenating them."""
 683    if not isinstance(expression, CONCATS) or (
 684        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 685        isinstance(expression, exp.ConcatWs)
 686        and not expression.expressions[0].is_string
 687    ):
 688        return expression
 689
 690    if isinstance(expression, exp.ConcatWs):
 691        sep_expr, *expressions = expression.expressions
 692        sep = sep_expr.name
 693        concat_type = exp.ConcatWs
 694    else:
 695        expressions = expression.expressions
 696        sep = ""
 697        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
 698
 699    new_args = []
 700    for is_string_group, group in itertools.groupby(
 701        expressions or expression.flatten(), lambda e: e.is_string
 702    ):
 703        if is_string_group:
 704            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 705        else:
 706            new_args.extend(group)
 707
 708    if len(new_args) == 1 and new_args[0].is_string:
 709        return new_args[0]
 710
 711    if concat_type is exp.ConcatWs:
 712        new_args = [sep_expr] + new_args
 713
 714    return concat_type(expressions=new_args)
 715
 716
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            # A branch that always matches: the CASE is replaced by its result
            if always_true(cond):
                return case.args["true"]

            # A branch that can never match is removed; once no branch is left,
            # the CASE is replaced by its default (or NULL when there is none)
            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        # Standalone IF; the branches of a CASE are handled above
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
 741
 742
# A (start, end) pair of dates; `_datetrunc_range` documents it as the half-open range [start, end)
DateRange = t.Tuple[datetime.date, datetime.date]
 744
 745
 746def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
 747    """
 748    Get the date range for a DATE_TRUNC equality comparison:
 749
 750    Example:
 751        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 752    Returns:
 753        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 754    """
 755    floor = date_floor(date, unit)
 756
 757    if date != floor:
 758        # This will always be False, except for NULL values.
 759        return None
 760
 761    return floor, floor + interval(unit)
 762
 763
 764def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 765    """Get the logical expression for a date range"""
 766    return exp.and_(
 767        left >= date_literal(drange[0]),
 768        left < date_literal(drange[1]),
 769        copy=False,
 770    )
 771
 772
 773def _datetrunc_eq(
 774    left: exp.Expression, date: datetime.date, unit: str
 775) -> t.Optional[exp.Expression]:
 776    drange = _datetrunc_range(date, unit)
 777    if not drange:
 778        return None
 779
 780    return _datetrunc_eq_expression(left, drange)
 781
 782
 783def _datetrunc_neq(
 784    left: exp.Expression, date: datetime.date, unit: str
 785) -> t.Optional[exp.Expression]:
 786    drange = _datetrunc_range(date, unit)
 787    if not drange:
 788        return None
 789
 790    return exp.and_(
 791        left < date_literal(drange[0]),
 792        left >= date_literal(drange[1]),
 793        copy=False,
 794    )
 795
 796
 797DateTruncBinaryTransform = t.Callable[
 798    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
 799]
 800DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 801    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
 802    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
 803    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
 804    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
 805    exp.EQ: _datetrunc_eq,
 806    exp.NEQ: _datetrunc_neq,
 807}
 808DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 809
 810
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Whether `left` is a DATE_TRUNC / TIMESTAMP_TRUNC call and `right` is a date literal."""
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
 813
 814
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Handles the binary comparisons listed in DATETRUNC_BINARY_COMPARISONS as well as
    `DATE_TRUNC(...) IN (<date literals>)`. The expression is returned unchanged when
    it does not match, when a date cannot be extracted, or when one of the caught
    exceptions is raised.
    """
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            # The truncation is on the right-hand side: mirror the comparison
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # The rules may return None (no rewrite possible); keep the expression then
        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                # Dates for which no range exists are skipped
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            # One range check per merged range, combined with OR
            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression
 867
 868
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
#
# The entries are the (side, kind) pairs that `remove_where_true` rewrites to a
# CROSS join when the join condition is always true.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
 878
 879
 880def remove_where_true(expression):
 881    for where in expression.find_all(exp.Where):
 882        if always_true(where.this):
 883            where.parent.set("where", None)
 884    for join in expression.find_all(exp.Join):
 885        if (
 886            always_true(join.args.get("on"))
 887            and not join.args.get("using")
 888            and not join.args.get("method")
 889            and (join.side, join.kind) in JOINS
 890        ):
 891            join.set("on", None)
 892            join.set("side", None)
 893            join.set("kind", "CROSS")
 894
 895
 896def always_true(expression):
 897    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 898        expression, exp.Literal
 899    )
 900
 901
 902def always_false(expression):
 903    return is_false(expression) or is_null(expression)
 904
 905
 906def is_complement(a, b):
 907    return isinstance(b, exp.Not) and b.this == a
 908
 909
 910def is_false(a: exp.Expression) -> bool:
 911    return type(a) is exp.Boolean and not a.this
 912
 913
 914def is_null(a: exp.Expression) -> bool:
 915    return type(a) is exp.Null
 916
 917
 918def eval_boolean(expression, a, b):
 919    if isinstance(expression, (exp.EQ, exp.Is)):
 920        return boolean_literal(a == b)
 921    if isinstance(expression, exp.NEQ):
 922        return boolean_literal(a != b)
 923    if isinstance(expression, exp.GT):
 924        return boolean_literal(a > b)
 925    if isinstance(expression, exp.GTE):
 926        return boolean_literal(a >= b)
 927    if isinstance(expression, exp.LT):
 928        return boolean_literal(a < b)
 929    if isinstance(expression, exp.LTE):
 930        return boolean_literal(a <= b)
 931    return None
 932
 933
 934def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 935    if isinstance(value, datetime.datetime):
 936        return value.date()
 937    if isinstance(value, datetime.date):
 938        return value
 939    try:
 940        return datetime.datetime.fromisoformat(value).date()
 941    except ValueError:
 942        return None
 943
 944
 945def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 946    if isinstance(value, datetime.datetime):
 947        return value
 948    if isinstance(value, datetime.date):
 949        return datetime.datetime(year=value.year, month=value.month, day=value.day)
 950    try:
 951        return datetime.datetime.fromisoformat(value)
 952    except ValueError:
 953        return None
 954
 955
 956def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 957    if not value:
 958        return None
 959    if to.is_type(exp.DataType.Type.DATE):
 960        return cast_as_date(value)
 961    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 962        return cast_as_datetime(value)
 963    return None
 964
 965
 966def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 967    if isinstance(cast, exp.Cast):
 968        to = cast.to
 969    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
 970        to = exp.DataType.build(exp.DataType.Type.DATE)
 971    else:
 972        return None
 973
 974    if isinstance(cast.this, exp.Literal):
 975        value: t.Any = cast.this.name
 976    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 977        value = extract_date(cast.this)
 978    else:
 979        return None
 980    return cast_value(value, to)
 981
 982
 983def _is_date_literal(expression: exp.Expression) -> bool:
 984    return extract_date(expression) is not None
 985
 986
 987def extract_interval(expression):
 988    n = int(expression.name)
 989    unit = expression.text("unit").lower()
 990
 991    try:
 992        return interval(unit, n)
 993    except (UnsupportedUnit, ModuleNotFoundError):
 994        return None
 995
 996
 997def date_literal(date):
 998    return exp.cast(
 999        exp.Literal.string(date),
1000        exp.DataType.Type.DATETIME
1001        if isinstance(date, datetime.datetime)
1002        else exp.DataType.Type.DATE,
1003    )
1004
1005
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` times the given `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spans[unit]
    return relativedelta(**{keyword: factor * n})
1027
1028
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of the enclosing `unit`.

    Supported units are year, quarter, month, week and day.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)

    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)

    if unit == "month":
        return d.replace(day=1)

    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        days_since_monday = d.weekday()
        return d - datetime.timedelta(days=days_since_monday)

    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1050
1051
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary; boundary values are returned unchanged."""
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
1059
1060
def boolean_literal(condition):
    """Return the TRUE expression for a truthy `condition`, FALSE otherwise."""
    if condition:
        return exp.true()
    return exp.false()
1063
1064
1065def _flat_simplify(expression, simplifier, root=True):
1066    if root or not expression.same_parent:
1067        operands = []
1068        queue = deque(expression.flatten(unnest=False))
1069        size = len(queue)
1070
1071        while queue:
1072            a = queue.popleft()
1073
1074            for b in queue:
1075                result = simplifier(expression, a, b)
1076
1077                if result and result is not expression:
1078                    queue.remove(b)
1079                    queue.appendleft(result)
1080                    break
1081            else:
1082                operands.append(a)
1083
1084        if len(operands) < size:
1085            return functools.reduce(
1086                lambda a, b: expression.__class__(this=a, expression=b), operands
1087            )
1088    return expression
1089
1090
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    None renders as "_", iterables as comma-joined items, non-expression values
    through `str`, and expressions through GEN_MAP when their exact type is
    registered there, falling back to "<key> <args>" otherwise.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    handler = GEN_MAP.get(type(expression))
    if handler is not None:
        return handler(expression)

    return f"{expression.key} {gen(expression.args.values())}"
1108
1109
# Maps an exact expression type to a function rendering it as a pseudo-sql string.
# `gen` looks the type up here before falling back to its generic rendering.
GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    # The first arg is rendered through its name, the remaining args generically
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Div: lambda e: _binary(e, "/"),
    exp.Dot: lambda e: _binary(e, "."),
    # DPipe and SafeDPipe intentionally render the same way
    exp.DPipe: lambda e: _binary(e, "||"),
    exp.SafeDPipe: lambda e: _binary(e, "||"),
    exp.EQ: lambda e: _binary(e, "="),
    exp.GT: lambda e: _binary(e, ">"),
    exp.GTE: lambda e: _binary(e, ">="),
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.ILike: lambda e: _binary(e, "ILIKE"),
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Is: lambda e: _binary(e, "IS"),
    exp.Like: lambda e: _binary(e, "LIKE"),
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.LT: lambda e: _binary(e, "<"),
    exp.LTE: lambda e: _binary(e, "<="),
    exp.Mod: lambda e: _binary(e, "%"),
    exp.Mul: lambda e: _binary(e, "*"),
    exp.Neg: lambda e: _unary(e, "-"),
    exp.NEQ: lambda e: _binary(e, "<>"),
    exp.Not: lambda e: _unary(e, "NOT"),
    exp.Null: lambda e: "NULL",
    exp.Or: lambda e: _binary(e, "OR"),
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Sub: lambda e: _binary(e, "-"),
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}
1147
1148
def _binary(e: exp.Binary, op: str) -> str:
    """Render a binary node as "<left> <op> <right>" using `gen` for both operands."""
    lhs = gen(e.left)
    rhs = gen(e.right)
    return f"{lhs} {op} {rhs}"
1151
1152
def _unary(e: exp.Unary, op: str) -> str:
    """Render a unary node as "<op> <operand>" using `gen` for the operand."""
    operand = gen(e.this)
    return f"{op} {operand}"
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date or interval unit is not supported by the simplifier."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression, constant_propagation=False):
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains one of the group by keys
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same for a HAVING clause that references a group by key
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes marked FINAL are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are simplified as non-root nodes
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Restore the parent link before running the remaining rules
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the root call swaps the simplified node into the tree
        if root:
            expression.replace(node)

        return node

    # while_changing re-applies _simplify; presumably until the tree stops changing
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether or not the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the wrapped function raises one of `exceptions`, the wrapper returns
    its first argument (the expression) unchanged. Any other exception
    propagates. The wrapper keeps the wrapped function's name, docstring and
    signature metadata.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    Non-BETWEEN expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    NOT NULL folds to NULL, NOT of an always-true operand to FALSE, NOT FALSE
    to TRUE, and a double negation is removed.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        if isinstance(condition, (exp.And, exp.Or)):
            # Negating a connector flips it and negates both sides
            flipped = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return flipped(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )

        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
def flatten(expression):
    """
    Remove parentheses around operands that use the same connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression

Flattens nested connectors of the same kind: A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Simplify AND/OR connectors by folding pairs of operands.

    Identical operands collapse into one, constant operands (TRUE/FALSE/NULL)
    are folded, and any remaining pair is handed to `_simplify_comparison`.
    Expressions that are not connectors are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # x AND x -> x, x OR x -> x
        if left == right:
            return left
        if isinstance(expression, exp.And):
            # FALSE is checked before NULL, so FALSE AND NULL -> FALSE
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # TRUE AND x -> x
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # TRUE is checked first, so TRUE OR NULL -> TRUE
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            # With no TRUE operand, any mix of NULL and FALSE yields NULL
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            # FALSE OR x -> x
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not (root or not expression.same_parent):
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not (isinstance(expression, exp.Connector) and (root or not expression.same_parent)):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keyed by the pseudo sql of each operand: duplicates collapse onto one entry
    by_sql = {gen(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # check if the operands are already sorted, if not sort them
    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))
    if out_of_order:
        ordered = [operand for _, operand in sorted(by_sql.items())]
        return combine(*ordered, copy=False)

    # we didn't have to sort but maybe we need to dedup
    if len(by_sql) < len(operands):
        return combine(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place and the input expression is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands to inspect
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # An operand of `a` that is the complement of `b` is replaced by
                # a boolean constant
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # The operands of `b` are a strict subset of those of `a`:
                # `a` is redundant and is replaced by a boolean constant
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # `a` and `b` share one operand and the other two are
                    # complements: both collapse to the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B.
Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Columns are replaced in place and the input expression is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node in its defining equality, literal)
        constant_mapping = {}
        # IF nodes are pruned from the walk
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    # Normalize to column on the left, literal on the right
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the column node that defines the constant itself, and
                # columns used in an `IS NULL` check
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here is how all the operands are referenced in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binary expressions and negated numbers.

    Binary expressions are delegated to `_flat_simplify` with `_simplify_binary`;
    a negated numeric literal becomes a single number literal of the opposite sign.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Flip the sign on the textual form of the number
            negated = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(negated)

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """Remove a pair of parentheses when it is not needed in its context.

    Returns the wrapped expression when the parentheses can be dropped,
    otherwise the Paren itself. Parenthesized SELECTs are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if removable else expression
NONNULL_CONSTANTS = (<class 'sqlglot.expressions.Literal'>, <class 'sqlglot.expressions.Boolean'>)
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons of a COALESCE against a constant.

    - `COALESCE(x)` with no fallback arguments, or whose first argument is a
      non-null constant, is replaced by that first argument (unless the parent
      is a hint).
    - In `COALESCE(a, ..., c, ...) <op> k`, with `k` constant and `c` the first
      constant fallback, the COALESCE is truncated before `c` and the comparison
      is split into a "not null" branch and an "is null" branch comparing `c`
      against `k`.

    Any other expression is returned unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg. The check must be on `arg`: `other` is
    # already known to be constant, so testing it would always stop at index 0.
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
SAFE_CONCATS = (<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.SafeDPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
    else:
        operands = expression.expressions
        sep = ""
        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat

    # Adjacent string literals are merged, everything else is kept as-is
    new_args = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for is_string_run, run in runs:
        if is_string_run:
            joined = sep.join(literal.name for literal in run)
            new_args.append(exp.Literal.string(joined))
        else:
            new_args.extend(run)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, the result of the first always-true branch is returned and
    always-false branches are removed; when no branch is left, the default
    (or NULL) is returned. For a standalone IF, the matching result is returned.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # Drop the branch that can never match
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    # IF nodes that are CASE branches are handled by the CASE logic above
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('RIGHT', 'OUTER'), ('', 'INNER'), ('RIGHT', ''), ('', '')}
def remove_where_true(expression):
    """
    Strip statically-true conditions from WHERE clauses and joins, in place.

    A WHERE whose predicate is always true is unset on its parent. A join whose ON
    condition is always true, which has neither USING nor a join method, and whose
    (side, kind) pair is listed in JOINS, has its ON and side cleared and its kind
    set to "CROSS".

    Args:
        expression: the expression tree to rewrite.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join_node in expression.find_all(exp.Join):
        # Guards are evaluated in the same order as the conditions they replace.
        if not always_true(join_node.args.get("on")):
            continue
        if join_node.args.get("using"):
            continue
        if join_node.args.get("method"):
            continue
        if (join_node.side, join_node.kind) not in JOINS:
            continue

        join_node.set("on", None)
        join_node.set("side", None)
        join_node.set("kind", "CROSS")
def always_true(expression):
    """
    Tell whether `expression` is treated as statically true.

    That is the case for a Boolean node whose value is truthy and for any Literal node.
    """
    if isinstance(expression, exp.Boolean) and expression.this:
        # Hand back the node's own truthy value rather than a fresh `True`.
        return expression.this
    return isinstance(expression, exp.Literal)
def always_false(expression):
    """Tell whether `expression` is statically false, i.e. a FALSE boolean or a NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
    """Tell whether `b` is a NOT node that wraps an expression equal to `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Tell whether `a` is exactly a Boolean node (subclasses excluded) with a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Tell whether `a` is exactly a Null node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Args:
        expression: the comparison node; EQ, Is, NEQ, GT, GTE, LT and LTE are handled.
        a: left operand value.
        b: right operand value.

    Returns:
        The boolean literal for the comparison result, or None when `expression`
        is not one of the handled comparison types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None

    return boolean_literal(outcome)
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`, if possible.

    Args:
        value: a datetime (its date part is used), a date (returned unchanged), or an
            ISO-8601 formatted string.

    Returns:
        The corresponding date, or None when `value` cannot be interpreted as one.
    """
    # datetime is a subclass of date, so it has to be checked first.
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: string is not ISO formatted.
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`, if possible.

    Args:
        value: a datetime (returned unchanged), a date (converted to midnight of that
            day), or an ISO-8601 formatted string.

    Returns:
        The corresponding datetime, or None when `value` cannot be interpreted as one.
    """
    # datetime is a subclass of date, so it has to be checked first.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: string is not ISO formatted.
        return None
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Convert `value` to the Python temporal value matching the data type `to`.

    Args:
        value: the raw value to convert; falsy values are never converted.
        to: the target data type.

    Returns:
        The result of cast_as_date for DATE targets, the result of cast_as_datetime
        for the other temporal targets, and None for anything else.
    """
    if value:
        if to.is_type(exp.DataType.Type.DATE):
            return cast_as_date(value)
        if to.is_type(*exp.DataType.TEMPORAL_TYPES):
            return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Evaluate a cast of a literal (possibly nested in further casts) to a Python temporal value.

    Args:
        cast: a Cast node, or a TsOrDsToDate node without a "format" argument.

    Returns:
        The value produced by cast_value for the innermost literal, or None when
        `cast` is not one of the supported nodes or does not wrap a literal or
        another supported cast.
    """
    if isinstance(cast, exp.Cast):
        target_type = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        target_type = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        return cast_value(operand.name, target_type)
    if isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: resolve the inner one first, then apply the outer type.
        return cast_value(extract_date(operand), target_type)
    return None
def extract_interval(expression):
    """
    Build the time delta described by an interval expression.

    Args:
        expression: node whose `name` holds the integer amount and whose "unit"
            text holds the unit.

    Returns:
        The delta produced by `interval`, or None when the amount is not an integer,
        the unit is unsupported, or dateutil is not installed.
    """
    try:
        # int() sits inside the try so a non-integer amount (e.g. '1.5') yields None
        # instead of propagating a ValueError.
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def date_literal(date):
    """
    Build a cast of a string literal holding `date`.

    The cast targets DATETIME when `date` is a `datetime.datetime` and DATE otherwise.
    """
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
    """
    Build a `relativedelta` spanning `n` times `unit`.

    Args:
        unit: one of "year", "quarter", "month", "week", "day", "hour", "minute", "second".
        n: how many units the delta covers.

    Returns:
        The matching `dateutil.relativedelta.relativedelta`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    # Imported here, before the unit is checked.
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, number of keyword units per one `unit`)
    conversions = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in conversions:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Round `d` down to the start of the `unit` it falls in.

    Args:
        d: the date to round down.
        unit: one of "year", "quarter", "month", "week", "day".

    Returns:
        The first day of the enclosing year / quarter / month / week, or `d` itself
        for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "day":
        return d
    if unit == "week":
        # weekday() is 0 for Monday, so weeks are taken to start on Monday.
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """
    Round `d` up to the next `unit` boundary.

    Returns `d` unchanged when it already sits on a boundary (its floor equals
    itself); otherwise returns the floor of `d` advanced by one `unit`.
    """
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
    """Return the TRUE expression when `condition` is truthy, the FALSE expression otherwise."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: t.Any) -> str:
    """Minimal pseudo-SQL generator that yields sortable, unique strings cheaply.

    Optimization needs to sort and dedupe SQL, and running the real generator for
    that is expensive, so this bare-bones one is used instead.

    None renders as "_", iterables as their comma-joined items, non-expression
    values through str(), and expressions through their GEN_MAP entry when one
    exists or else as the expression key followed by its generated args.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if isinstance(expression, exp.Expression):
        node_type = type(expression)
        if node_type in GEN_MAP:
            return GEN_MAP[node_type](expression)
        return f"{expression.key} {gen(expression.args.values())}"

    return str(expression)

Simple pseudo-SQL generator for quickly generating sortable and unique strings.

Sorting and deduping SQL is a necessary step for optimization. Calling the actual generator is expensive, so we have a bare-minimum SQL generator here.

GEN_MAP = {<class 'sqlglot.expressions.Add'>: <function <lambda>>, <class 'sqlglot.expressions.And'>: <function <lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function <lambda>>, <class 'sqlglot.expressions.Between'>: <function <lambda>>, <class 'sqlglot.expressions.Boolean'>: <function <lambda>>, <class 'sqlglot.expressions.Bracket'>: <function <lambda>>, <class 'sqlglot.expressions.Column'>: <function <lambda>>, <class 'sqlglot.expressions.DataType'>: <function <lambda>>, <class 'sqlglot.expressions.Div'>: <function <lambda>>, <class 'sqlglot.expressions.Dot'>: <function <lambda>>, <class 'sqlglot.expressions.DPipe'>: <function <lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.Identifier'>: <function <lambda>>, <class 'sqlglot.expressions.ILike'>: <function <lambda>>, <class 'sqlglot.expressions.In'>: <function <lambda>>, <class 'sqlglot.expressions.Is'>: <function <lambda>>, <class 'sqlglot.expressions.Like'>: <function <lambda>>, <class 'sqlglot.expressions.Literal'>: <function <lambda>>, <class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.Mod'>: <function <lambda>>, <class 'sqlglot.expressions.Mul'>: <function <lambda>>, <class 'sqlglot.expressions.Neg'>: <function <lambda>>, <class 'sqlglot.expressions.NEQ'>: <function <lambda>>, <class 'sqlglot.expressions.Not'>: <function <lambda>>, <class 'sqlglot.expressions.Null'>: <function <lambda>>, <class 'sqlglot.expressions.Or'>: <function <lambda>>, <class 'sqlglot.expressions.Paren'>: <function <lambda>>, <class 'sqlglot.expressions.Sub'>: <function <lambda>>, <class 'sqlglot.expressions.Subquery'>: <function <lambda>>, <class 'sqlglot.expressions.Table'>: <function <lambda>>, <class 
'sqlglot.expressions.Var'>: <function <lambda>>}