Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import functools
   5import itertools
   6import typing as t
   7from collections import deque
   8from decimal import Decimal
   9
  10import sqlglot
  11from sqlglot import Dialect, exp
  12from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
  13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot.dialects.dialect import DialectType
  17
  18    DateTruncBinaryTransform = t.Callable[
  19        [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
  20    ]
  21
# Final means that an expression should not be simplified.
# This is the key looked up in `Expression.meta`; `simplify` returns any node
# carrying a truthy value under it unchanged.
FINAL = "final"
  24
  25
  26class UnsupportedUnit(Exception):
  27    pass
  28
  29
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rewrite rules are applied to every node by the nested `_simplify` helper,
    which is driven through `while_changing`; afterwards `remove_where_true` is run
    on the result.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and handed to the DATE_TRUNC
            simplification rule

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as final are returned untouched (and their children are not visited).
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains a GROUP BY key anywhere inside it.
            for e in expression.selects:
                for node, *_ in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause: it is frozen as a whole if it mentions a key.
            having = expression.args.get("having")
            if having:
                for node, *_ in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; nested calls are never treated as the root.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have returned a different node; give it the original
        # parent because later rules inspect it (e.g. `simplify_parens` reads
        # `expression.parent`).
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)

        # Only the root call splices the result into the tree itself.
        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 112
 113
 114def catch(*exceptions):
 115    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 116
 117    def decorator(func):
 118        def wrapped(expression, *args, **kwargs):
 119            try:
 120                return func(expression, *args, **kwargs)
 121            except exceptions:
 122                return expression
 123
 124        return wrapped
 125
 126    return decorator
 127
 128
 129def rewrite_between(expression: exp.Expression) -> exp.Expression:
 130    """Rewrite x between y and z to x >= y AND x <= z.
 131
 132    This is done because comparison simplification is only done on lt/lte/gt/gte.
 133    """
 134    if isinstance(expression, exp.Between):
 135        negate = isinstance(expression.parent, exp.Not)
 136
 137        expression = exp.and_(
 138            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 139            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 140            copy=False,
 141        )
 142
 143        if negate:
 144            expression = exp.paren(expression, copy=False)
 145
 146    return expression
 147
 148
# Maps each comparison class to the class of its logical negation
# (used by `simplify_not` to rewrite e.g. NOT a < b as a >= b).
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 157
 158
 159def simplify_not(expression):
 160    """
 161    Demorgan's Law
 162    NOT (x OR y) -> NOT x AND NOT y
 163    NOT (x AND y) -> NOT x OR NOT y
 164    """
 165    if isinstance(expression, exp.Not):
 166        this = expression.this
 167        if is_null(this):
 168            return exp.null()
 169        if this.__class__ in COMPLEMENT_COMPARISONS:
 170            return COMPLEMENT_COMPARISONS[this.__class__](
 171                this=this.this, expression=this.expression
 172            )
 173        if isinstance(this, exp.Paren):
 174            condition = this.unnest()
 175            if isinstance(condition, exp.And):
 176                return exp.or_(
 177                    exp.not_(condition.left, copy=False),
 178                    exp.not_(condition.right, copy=False),
 179                    copy=False,
 180                )
 181            if isinstance(condition, exp.Or):
 182                return exp.and_(
 183                    exp.not_(condition.left, copy=False),
 184                    exp.not_(condition.right, copy=False),
 185                    copy=False,
 186                )
 187            if is_null(condition):
 188                return exp.null()
 189        if always_true(this):
 190            return exp.false()
 191        if is_false(this):
 192            return exp.true()
 193        if isinstance(this, exp.Not):
 194            # double negation
 195            # NOT NOT x -> x
 196            return this.this
 197    return expression
 198
 199
 200def flatten(expression):
 201    """
 202    A AND (B AND C) -> A AND B AND C
 203    A OR (B OR C) -> A OR B OR C
 204    """
 205    if isinstance(expression, exp.Connector):
 206        for node in expression.args.values():
 207            child = node.unnest()
 208            if isinstance(child, expression.__class__):
 209                node.replace(child)
 210    return expression
 211
 212
 213def simplify_connectors(expression, root=True):
 214    def _simplify_connectors(expression, left, right):
 215        if left == right:
 216            return left
 217        if isinstance(expression, exp.And):
 218            if is_false(left) or is_false(right):
 219                return exp.false()
 220            if is_null(left) or is_null(right):
 221                return exp.null()
 222            if always_true(left) and always_true(right):
 223                return exp.true()
 224            if always_true(left):
 225                return right
 226            if always_true(right):
 227                return left
 228            return _simplify_comparison(expression, left, right)
 229        elif isinstance(expression, exp.Or):
 230            if always_true(left) or always_true(right):
 231                return exp.true()
 232            if is_false(left) and is_false(right):
 233                return exp.false()
 234            if (
 235                (is_null(left) and is_null(right))
 236                or (is_null(left) and is_false(right))
 237                or (is_false(left) and is_null(right))
 238            ):
 239                return exp.null()
 240            if is_false(left):
 241                return right
 242            if is_false(right):
 243                return left
 244            return _simplify_comparison(expression, left, right, or_=True)
 245
 246    if isinstance(expression, exp.Connector):
 247        return _flat_simplify(expression, _simplify_connectors, root)
 248    return expression
 249
 250
# "Less than" style and "greater than" style comparison classes, grouped so they
# can be used directly with isinstance().
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Comparison classes the simplification rules in this module operate on.
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Comparison class to use when the two operands are swapped (a < b  <=>  b > a).
# Classes that are absent (e.g. EQ/NEQ) are looked up with a default by callers.
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Operands containing any of these are excluded from comparison merging in
# `_simplify_comparison`.
NONDETERMINISTIC = (exp.Rand, exp.Randn)
 270
 271
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that share a non-constant operand.

    Args:
        expression: the connector node the two operands belong to.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the operands are combined with OR, False for AND.

    Returns:
        A replacement for the pair (one of the operands or FALSE), `expression`
        itself when no non-shared operand could be extracted, or None when the pair
        cannot be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands that appear in both comparisons; only non-constant, deterministic
        # ones count as the shared "column".
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The remaining operand on each side is the value being compared against.
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Convert both values to comparable Python values: numbers, strings or dates.
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None

            # Try the pair in both orders: (left, right) and (right, left).
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two bounds in the same direction: keep only one of them.
                # NOTE(review): for equal values with mixed strictness (e.g.
                # `x <= 1 AND x < 1`) this keeps `left`; confirm that operand ordering
                # upstream guarantees the intended one comes first.
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Contradictory bounds under AND collapse to FALSE.
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality combined with another comparison is either
                        # contradictory (FALSE) or makes the other one redundant.
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 330
 331
 332def remove_complements(expression, root=True):
 333    """
 334    Removing complements.
 335
 336    A AND NOT A -> FALSE
 337    A OR NOT A -> TRUE
 338    """
 339    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 340        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 341
 342        for a, b in itertools.permutations(expression.flatten(), 2):
 343            if is_complement(a, b):
 344                return complement
 345    return expression
 346
 347
 348def uniq_sort(expression, root=True):
 349    """
 350    Uniq and sort a connector.
 351
 352    C AND A AND B AND B -> A AND B AND C
 353    """
 354    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 355        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 356        flattened = tuple(expression.flatten())
 357        deduped = {gen(e): e for e in flattened}
 358        arr = tuple(deduped.items())
 359
 360        # check if the operands are already sorted, if not sort them
 361        # A AND C AND B -> A AND B AND C
 362        for i, (sql, e) in enumerate(arr[1:]):
 363            if sql < arr[i][0]:
 364                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 365                break
 366        else:
 367            # we didn't have to sort but maybe we need to dedup
 368            if len(deduped) < len(flattened):
 369                expression = result_func(*deduped.values(), copy=False)
 370
 371    return expression
 372
 373
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are done in place by replacing operands with boolean constants or
    with the shared operand; the connector node passed in is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector we look for is the opposite of the outer one.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # If b is the complement of one side of a, that side is replaced with the
                # neutral element of `kind` (TRUE for AND, FALSE for OR).
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # If b's operands are a proper subset of a's, a as a whole is replaced
                # with the neutral element of the outer connector.
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # One side is shared by a and b and the other sides are complements:
                    # both a and b collapse to the shared side.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 412
 413
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only AND nodes that are already in disjunctive normal form are processed; the
    columns are replaced in place and the same node is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node inside its `col = literal` equality, literal)
        constant_mapping = {}
        # IF nodes are pruned from the walk.
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Replace every other occurrence of a mapped column, except the
                # occurrence that defines the constant (same id) and `col IS NULL` checks.
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
 451
 452
# Date add/sub operation -> arithmetic operation that undoes it
# (used when moving the interval to the other side of a comparison).
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All invertible operations: the date operations above plus plain + and -.
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 465
 466
def _is_number(expression: exp.Expression) -> bool:
    """Return the expression's `is_number` flag (function form, usable as a predicate)."""
    return expression.is_number
 469
 470
 471def _is_interval(expression: exp.Expression) -> bool:
 472    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 473
 474
 475@catch(ModuleNotFoundError, UnsupportedUnit)
 476def simplify_equality(expression: exp.Expression) -> exp.Expression:
 477    """
 478    Use the subtraction and addition properties of equality to simplify expressions:
 479
 480        x + 1 = 3 becomes x = 2
 481
 482    There are two binary operations in the above expression: + and =
 483    Here's how we reference all the operands in the code below:
 484
 485          l     r
 486        x + 1 = 3
 487        a   b
 488    """
 489    if isinstance(expression, COMPARISONS):
 490        l, r = expression.left, expression.right
 491
 492        if not l.__class__ in INVERSE_OPS:
 493            return expression
 494
 495        if r.is_number:
 496            a_predicate = _is_number
 497            b_predicate = _is_number
 498        elif _is_date_literal(r):
 499            a_predicate = _is_date_literal
 500            b_predicate = _is_interval
 501        else:
 502            return expression
 503
 504        if l.__class__ in INVERSE_DATE_OPS:
 505            l = t.cast(exp.IntervalOp, l)
 506            a = l.this
 507            b = l.interval()
 508        else:
 509            l = t.cast(exp.Binary, l)
 510            a, b = l.left, l.right
 511
 512        if not a_predicate(a) and b_predicate(b):
 513            pass
 514        elif not a_predicate(b) and b_predicate(a):
 515            a, b = b, a
 516        else:
 517            return expression
 518
 519        return expression.__class__(
 520            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 521        )
 522    return expression
 523
 524
 525def simplify_literals(expression, root=True):
 526    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 527        return _flat_simplify(expression, _simplify_binary, root)
 528
 529    if isinstance(expression, exp.Neg):
 530        this = expression.this
 531        if this.is_number:
 532            value = this.name
 533            if value[0] == "-":
 534                return exp.Literal.number(value[1:])
 535            return exp.Literal.number(f"-{value}")
 536
 537    if type(expression) in INVERSE_DATE_OPS:
 538        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 539
 540    return expression
 541
 542
 543def _simplify_binary(expression, a, b):
 544    if isinstance(expression, exp.Is):
 545        if isinstance(b, exp.Not):
 546            c = b.this
 547            not_ = True
 548        else:
 549            c = b
 550            not_ = False
 551
 552        if is_null(c):
 553            if isinstance(a, exp.Literal):
 554                return exp.true() if not_ else exp.false()
 555            if is_null(a):
 556                return exp.false() if not_ else exp.true()
 557    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
 558        return None
 559    elif is_null(a) or is_null(b):
 560        return exp.null()
 561
 562    if a.is_number and b.is_number:
 563        num_a = int(a.name) if a.is_int else Decimal(a.name)
 564        num_b = int(b.name) if b.is_int else Decimal(b.name)
 565
 566        if isinstance(expression, exp.Add):
 567            return exp.Literal.number(num_a + num_b)
 568        if isinstance(expression, exp.Mul):
 569            return exp.Literal.number(num_a * num_b)
 570
 571        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 572        if isinstance(expression, exp.Sub):
 573            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 574        if isinstance(expression, exp.Div):
 575            # engines have differing int div behavior so intdiv is not safe
 576            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 577                return None
 578            return exp.Literal.number(num_a / num_b)
 579
 580        boolean = eval_boolean(expression, num_a, num_b)
 581
 582        if boolean:
 583            return boolean
 584    elif a.is_string and b.is_string:
 585        boolean = eval_boolean(expression, a.this, b.this)
 586
 587        if boolean:
 588            return boolean
 589    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 590        a, b = extract_date(a), extract_interval(b)
 591        if a and b:
 592            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 593                return date_literal(a + b)
 594            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 595                return date_literal(a - b)
 596    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 597        a, b = extract_interval(a), extract_date(b)
 598        # you cannot subtract a date from an interval
 599        if a and b and isinstance(expression, exp.Add):
 600            return date_literal(a + b)
 601    elif _is_date_literal(a) and _is_date_literal(b):
 602        if isinstance(expression, exp.Predicate):
 603            a, b = extract_date(a), extract_date(b)
 604            boolean = eval_boolean(expression, a, b)
 605            if boolean:
 606                return boolean
 607
 608    return None
 609
 610
 611def simplify_parens(expression):
 612    if not isinstance(expression, exp.Paren):
 613        return expression
 614
 615    this = expression.this
 616    parent = expression.parent
 617
 618    if not isinstance(this, exp.Select) and (
 619        not isinstance(parent, (exp.Condition, exp.Binary))
 620        or isinstance(parent, exp.Paren)
 621        or not isinstance(this, exp.Binary)
 622        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
 623        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 624        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 625        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 626    ):
 627        return this
 628    return expression
 629
 630
# Node classes treated as constants that are not NULL.
NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

# Node classes treated as constants, NULL included.
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)
 641
 642
 643def _is_nonnull_constant(expression: exp.Expression) -> bool:
 644    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
 645
 646
 647def _is_constant(expression: exp.Expression) -> bool:
 648    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
 649
 650
 651def simplify_coalesce(expression):
 652    # COALESCE(x) -> x
 653    if (
 654        isinstance(expression, exp.Coalesce)
 655        and (not expression.expressions or _is_nonnull_constant(expression.this))
 656        # COALESCE is also used as a Spark partitioning hint
 657        and not isinstance(expression.parent, exp.Hint)
 658    ):
 659        return expression.this
 660
 661    if not isinstance(expression, COMPARISONS):
 662        return expression
 663
 664    if isinstance(expression.left, exp.Coalesce):
 665        coalesce = expression.left
 666        other = expression.right
 667    elif isinstance(expression.right, exp.Coalesce):
 668        coalesce = expression.right
 669        other = expression.left
 670    else:
 671        return expression
 672
 673    # This transformation is valid for non-constants,
 674    # but it really only does anything if they are both constants.
 675    if not _is_constant(other):
 676        return expression
 677
 678    # Find the first constant arg
 679    for arg_index, arg in enumerate(coalesce.expressions):
 680        if _is_constant(arg):
 681            break
 682    else:
 683        return expression
 684
 685    coalesce.set("expressions", coalesce.expressions[:arg_index])
 686
 687    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 688    # since we already remove COALESCE at the top of this function.
 689    coalesce = coalesce if coalesce.expressions else coalesce.this
 690
 691    # This expression is more complex than when we started, but it will get simplified further
 692    return exp.paren(
 693        exp.or_(
 694            exp.and_(
 695                coalesce.is_(exp.null()).not_(copy=False),
 696                expression.copy(),
 697                copy=False,
 698            ),
 699            exp.and_(
 700                coalesce.is_(exp.null()),
 701                type(expression)(this=arg.copy(), expression=other.copy()),
 702                copy=False,
 703            ),
 704            copy=False,
 705        )
 706    )
 707
 708
# Node classes handled by `simplify_concat`.
CONCATS = (exp.Concat, exp.DPipe)
 710
 711
 712def simplify_concat(expression):
 713    """Reduces all groups that contain string literals by concatenating them."""
 714    if not isinstance(expression, CONCATS) or (
 715        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 716        isinstance(expression, exp.ConcatWs)
 717        and not expression.expressions[0].is_string
 718    ):
 719        return expression
 720
 721    if isinstance(expression, exp.ConcatWs):
 722        sep_expr, *expressions = expression.expressions
 723        sep = sep_expr.name
 724        concat_type = exp.ConcatWs
 725        args = {}
 726    else:
 727        expressions = expression.expressions
 728        sep = ""
 729        concat_type = exp.Concat
 730        args = {
 731            "safe": expression.args.get("safe"),
 732            "coalesce": expression.args.get("coalesce"),
 733        }
 734
 735    new_args = []
 736    for is_string_group, group in itertools.groupby(
 737        expressions or expression.flatten(), lambda e: e.is_string
 738    ):
 739        if is_string_group:
 740            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 741        else:
 742            new_args.extend(group)
 743
 744    if len(new_args) == 1 and new_args[0].is_string:
 745        return new_args[0]
 746
 747    if concat_type is exp.ConcatWs:
 748        new_args = [sep_expr] + new_args
 749
 750    return concat_type(expressions=new_args, **args)
 751
 752
def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches are examined in order: the value of the first branch whose
    condition is always true is returned, and branches whose condition is always
    false are removed from the CASE in place. For a standalone IF, the matching
    branch (or NULL when there is no false branch) is returned.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # Drop the dead branch; when none are left the CASE reduces to its
                # default (or NULL if there is none).
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        # IF nodes that are branches of a CASE are handled by the CASE rule above.
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
 777
 778
# A (start, end) pair of dates; `_datetrunc_range` documents it as the half-open range [min, max).
DateRange = t.Tuple[datetime.date, datetime.date]
 780
 781
 782def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 783    """
 784    Get the date range for a DATE_TRUNC equality comparison:
 785
 786    Example:
 787        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 788    Returns:
 789        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 790    """
 791    floor = date_floor(date, unit, dialect)
 792
 793    if date != floor:
 794        # This will always be False, except for NULL values.
 795        return None
 796
 797    return floor, floor + interval(unit)
 798
 799
 800def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 801    """Get the logical expression for a date range"""
 802    return exp.and_(
 803        left >= date_literal(drange[0]),
 804        left < date_literal(drange[1]),
 805        copy=False,
 806    )
 807
 808
 809def _datetrunc_eq(
 810    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 811) -> t.Optional[exp.Expression]:
 812    drange = _datetrunc_range(date, unit, dialect)
 813    if not drange:
 814        return None
 815
 816    return _datetrunc_eq_expression(left, drange)
 817
 818
 819def _datetrunc_neq(
 820    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 821) -> t.Optional[exp.Expression]:
 822    drange = _datetrunc_range(date, unit, dialect)
 823    if not drange:
 824        return None
 825
 826    return exp.and_(
 827        left < date_literal(drange[0]),
 828        left >= date_literal(drange[1]),
 829        copy=False,
 830    )
 831
 832
# Comparison class -> function that rewrites `DATE_TRUNC(unit, l) <op> dt` as a
# comparison on `l` itself. Arguments: l (truncated operand), dt (date), u (unit),
# d (dialect).
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # l < dt when dt is on a unit boundary, otherwise l < start of the next unit
    exp.LT: lambda l, dt, u, d: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
    # l >= start of the unit following dt's unit
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    # l < start of the unit following dt's unit
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    # l >= dt rounded up with date_ceil
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Every expression class `simplify_datetrunc` looks at: IN plus the binary comparisons above.
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
# Truncation node classes recognized by the DATE_TRUNC rules.
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 844
 845
 846def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 847    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 848
 849
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """
    Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Three cases are handled:
    - a truncation of a date literal is folded into a date literal
    - a binary comparison between a truncation and a date literal is rewritten as a
      comparison on the truncated operand (see DATETRUNC_BINARY_COMPARISONS)
    - `DATE_TRUNC(...) IN (<date literals>)` is rewritten as an OR of range checks

    Anything else is returned unchanged (also when one of the exceptions listed in
    the `catch` decorator is raised).
    """
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        # Truncating a known date can be computed right away.
        date = extract_date(expression.this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # The transform may return None (no equivalent range); keep the original then.
        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                # Dates without a range are skipped.
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression
 901
 902
 903def sort_comparison(expression: exp.Expression) -> exp.Expression:
 904    if expression.__class__ in COMPLEMENT_COMPARISONS:
 905        l, r = expression.this, expression.expression
 906        l_column = isinstance(l, exp.Column)
 907        r_column = isinstance(r, exp.Column)
 908        l_const = _is_constant(l)
 909        r_const = _is_constant(r)
 910
 911        if (l_column and not r_column) or (r_const and not l_const):
 912            return expression
 913        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
 914            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
 915                this=r, expression=l
 916            )
 917    return expression
 918
 919
 920# CROSS joins result in an empty table if the right table is empty.
 921# So we can only simplify certain types of joins to CROSS.
 922# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 923JOINS = {
 924    ("", ""),
 925    ("", "INNER"),
 926    ("RIGHT", ""),
 927    ("RIGHT", "OUTER"),
 928}
 929
 930
 931def remove_where_true(expression):
 932    for where in expression.find_all(exp.Where):
 933        if always_true(where.this):
 934            where.parent.set("where", None)
 935    for join in expression.find_all(exp.Join):
 936        if (
 937            always_true(join.args.get("on"))
 938            and not join.args.get("using")
 939            and not join.args.get("method")
 940            and (join.side, join.kind) in JOINS
 941        ):
 942            join.set("on", None)
 943            join.set("side", None)
 944            join.set("kind", "CROSS")
 945
 946
 947def always_true(expression):
 948    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 949        expression, exp.Literal
 950    )
 951
 952
 953def always_false(expression):
 954    return is_false(expression) or is_null(expression)
 955
 956
 957def is_complement(a, b):
 958    return isinstance(b, exp.Not) and b.this == a
 959
 960
 961def is_false(a: exp.Expression) -> bool:
 962    return type(a) is exp.Boolean and not a.this
 963
 964
 965def is_null(a: exp.Expression) -> bool:
 966    return type(a) is exp.Null
 967
 968
 969def eval_boolean(expression, a, b):
 970    if isinstance(expression, (exp.EQ, exp.Is)):
 971        return boolean_literal(a == b)
 972    if isinstance(expression, exp.NEQ):
 973        return boolean_literal(a != b)
 974    if isinstance(expression, exp.GT):
 975        return boolean_literal(a > b)
 976    if isinstance(expression, exp.GTE):
 977        return boolean_literal(a >= b)
 978    if isinstance(expression, exp.LT):
 979        return boolean_literal(a < b)
 980    if isinstance(expression, exp.LTE):
 981        return boolean_literal(a <= b)
 982    return None
 983
 984
 985def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 986    if isinstance(value, datetime.datetime):
 987        return value.date()
 988    if isinstance(value, datetime.date):
 989        return value
 990    try:
 991        return datetime.datetime.fromisoformat(value).date()
 992    except ValueError:
 993        return None
 994
 995
 996def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 997    if isinstance(value, datetime.datetime):
 998        return value
 999    if isinstance(value, datetime.date):
1000        return datetime.datetime(year=value.year, month=value.month, day=value.day)
1001    try:
1002        return datetime.datetime.fromisoformat(value)
1003    except ValueError:
1004        return None
1005
1006
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python object matching the temporal type `to`.

    Args:
        value: a string, date or datetime; falsy values yield None.
        to: the target data type.

    Returns:
        A `datetime.date` for DATE targets, a `datetime.datetime` for the
        other temporal types, and None for non-temporal targets or values
        that cannot be converted.
    """
    if not value:
        return None
    # DATE is checked first so it is not handled by the generic temporal branch.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1015
1016
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract the constant date/datetime denoted by a cast of a literal.

    Handles `CAST(<literal> AS <type>)` and format-less `TsOrDsToDate`
    nodes, including nested combinations of the two.

    Returns:
        The converted value, or None when `cast` is not such a node, wraps
        something other than a literal/cast, or the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format this behaves like a cast to DATE.
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Resolve the inner cast first, then convert its result to `to`.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1032
1033
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True when `extract_date` can pull a constant out of `expression`."""
    extracted = extract_date(expression)
    return extracted is not None
1036
1037
def extract_interval(expression):
    """Build a relativedelta from an interval node, or return None.

    None is returned when the amount is not an integer, the unit is not
    supported, or `dateutil` is not installed.
    """
    try:
        amount = int(expression.name)
        unit_name = expression.text("unit").lower()
        return interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1045
1046
def date_literal(date):
    """Return `CAST('<date>' AS DATE)`, or `AS DATETIME` for datetime values."""
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
1054
1055
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` times `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute or second.
        ModuleNotFoundError: if `dateutil` is not installed.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of it make up one unit)
    spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }.get(unit)

    if spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, per_unit = spec
    return relativedelta(**{keyword: per_unit * n})
1077
1078
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date (or datetime) to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`, the first day of the week relative
            to Monday (0 = Monday, -1 = Sunday).

    Returns:
        `d` moved back to the first day of its year/quarter/month/week, or
        `d` itself for "day".

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=d.month - (d.month - 1) % 3, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. The distance back to the
        # week start must wrap modulo 7: without it, a day that is itself the
        # week start (e.g. a Sunday with WEEK_OFFSET = -1) would jump a full
        # week back, and positive offsets could move the date forward.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1100
1101
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; boundaries map to themselves."""
    floor = date_floor(d, unit, dialect)
    return d if floor == d else floor + interval(unit)
1109
1110
def boolean_literal(condition):
    """Return the TRUE node when `condition` is truthy, else the FALSE node."""
    if condition:
        return exp.true()
    return exp.false()
1113
1114
def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the flattened operands of a chain of binary nodes.

    `simplifier(expression, a, b)` is tried on operand pairs; a truthy result
    other than `expression` replaces both operands. The chain is rebuilt only
    if at least one pair was merged. Nothing happens for a non-root node whose
    parent is part of the same chain (`expression.same_parent`).
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    initial_count = len(pending)
    settled = []

    while pending:
        current = pending.popleft()

        for other in pending:
            merged = simplifier(expression, current, other)

            if merged and merged is not expression:
                # Swap the pair for its merged form and retry from the front.
                pending.remove(other)
                pending.appendleft(merged)
                break
        else:
            # `current` could not be combined with any remaining operand.
            settled.append(current)

    if len(settled) < initial_count:
        node_type = expression.__class__
        return functools.reduce(
            lambda acc, operand: node_type(this=acc, expression=operand), settled
        )
    return expression
1139
1140
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    # Node types without a dedicated renderer fall back to "<key> <args>".
    renderer = GEN_MAP.get(type(expression))
    if renderer is not None:
        return renderer(expression)
    return f"{expression.key} {gen(expression.args.values())}"
1158
1159
# Per-node-type renderers used by `gen`; each takes the node and returns its
# pseudo-SQL string.
GEN_MAP = {
    exp.Add: lambda node: _binary(node, "+"),
    exp.And: lambda node: _binary(node, "AND"),
    exp.Anonymous: lambda node: (
        f"{node.this.upper()} {','.join(gen(arg) for arg in node.expressions)}"
    ),
    exp.Between: lambda node: (
        f"{gen(node.this)} BETWEEN {gen(node.args.get('low'))} AND {gen(node.args.get('high'))}"
    ),
    exp.Boolean: lambda node: "TRUE" if node.this else "FALSE",
    exp.Bracket: lambda node: f"{gen(node.this)}[{gen(node.expressions)}]",
    exp.Column: lambda node: ".".join(gen(part) for part in node.parts),
    exp.DataType: lambda node: f"{node.this.name} {gen(tuple(node.args.values())[1:])}",
    exp.Div: lambda node: _binary(node, "/"),
    exp.Dot: lambda node: _binary(node, "."),
    exp.EQ: lambda node: _binary(node, "="),
    exp.GT: lambda node: _binary(node, ">"),
    exp.GTE: lambda node: _binary(node, ">="),
    exp.Identifier: lambda node: f'"{node.name}"' if node.quoted else node.name,
    exp.ILike: lambda node: _binary(node, "ILIKE"),
    exp.In: lambda node: f"{gen(node.this)} IN ({gen(tuple(node.args.values())[1:])})",
    exp.Is: lambda node: _binary(node, "IS"),
    exp.Like: lambda node: _binary(node, "LIKE"),
    exp.Literal: lambda node: f"'{node.name}'" if node.is_string else node.name,
    exp.LT: lambda node: _binary(node, "<"),
    exp.LTE: lambda node: _binary(node, "<="),
    exp.Mod: lambda node: _binary(node, "%"),
    exp.Mul: lambda node: _binary(node, "*"),
    exp.Neg: lambda node: _unary(node, "-"),
    exp.NEQ: lambda node: _binary(node, "<>"),
    exp.Not: lambda node: _unary(node, "NOT"),
    exp.Null: lambda node: "NULL",
    exp.Or: lambda node: _binary(node, "OR"),
    exp.Paren: lambda node: f"({gen(node.this)})",
    exp.Sub: lambda node: _binary(node, "-"),
    exp.Subquery: lambda node: f"({gen(node.args.values())})",
    exp.Table: lambda node: gen(node.args.values()),
    exp.Var: lambda node: node.name,
}
1195
1196
def _binary(e: exp.Binary, op: str) -> str:
    """Render a binary node as "<left> <op> <right>" using `gen`."""
    lhs = gen(e.left)
    rhs = gen(e.right)
    return f"{lhs} {op} {rhs}"
1199
1200
def _unary(e: exp.Unary, op: str) -> str:
    """Render a unary node as "<op> <operand>" using `gen`."""
    operand = gen(e.this)
    return f"{op} {operand}"
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date/interval helper is given a unit it cannot handle."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None):
 31def simplify(
 32    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
 33):
 34    """
 35    Rewrite sqlglot AST to simplify expressions.
 36
 37    Example:
 38        >>> import sqlglot
 39        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 40        >>> simplify(expression).sql()
 41        'TRUE'
 42
 43    Args:
 44        expression (sqlglot.Expression): expression to simplify
 45        constant_propagation: whether or not the constant propagation rule should be used
 46
 47    Returns:
 48        sqlglot.Expression: simplified expression
 49    """
 50
 51    dialect = Dialect.get_or_raise(dialect)
 52
 53    def _simplify(expression, root=True):
 54        if expression.meta.get(FINAL):
 55            return expression
 56
 57        # group by expressions cannot be simplified, for example
 58        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 59        # the projection must exactly match the group by key
 60        group = expression.args.get("group")
 61
 62        if group and hasattr(expression, "selects"):
 63            groups = set(group.expressions)
 64            group.meta[FINAL] = True
 65
 66            for e in expression.selects:
 67                for node, *_ in e.walk():
 68                    if node in groups:
 69                        e.meta[FINAL] = True
 70                        break
 71
 72            having = expression.args.get("having")
 73            if having:
 74                for node, *_ in having.walk():
 75                    if node in groups:
 76                        having.meta[FINAL] = True
 77                        break
 78
 79        # Pre-order transformations
 80        node = expression
 81        node = rewrite_between(node)
 82        node = uniq_sort(node, root)
 83        node = absorb_and_eliminate(node, root)
 84        node = simplify_concat(node)
 85        node = simplify_conditionals(node)
 86
 87        if constant_propagation:
 88            node = propagate_constants(node, root)
 89
 90        exp.replace_children(node, lambda e: _simplify(e, False))
 91
 92        # Post-order transformations
 93        node = simplify_not(node)
 94        node = flatten(node)
 95        node = simplify_connectors(node, root)
 96        node = remove_complements(node, root)
 97        node = simplify_coalesce(node)
 98        node.parent = expression.parent
 99        node = simplify_literals(node, root)
100        node = simplify_equality(node)
101        node = simplify_parens(node)
102        node = simplify_datetrunc(node, dialect)
103        node = sort_comparison(node)
104
105        if root:
106            expression.replace(node)
107
108        return node
109
110    expression = while_changing(expression, _simplify)
111    remove_where_true(expression)
112    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether or not the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    When the wrapped function raises one of `exceptions`, its first argument
    (the expression) is returned unchanged; any other exception propagates.
    The wrapper keeps the wrapped function's name, docstring and signature
    metadata so that documentation and introspection show the real function.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    The conjunction is parenthesized when the BETWEEN sits directly under a NOT.
    """
    if not isinstance(expression, exp.Between):
        return expression

    negated = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(rewritten, copy=False) if negated else rewritten

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT node.

    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also handles NOT NULL -> NULL, complemented comparisons, constant
    operands and double negation.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if operand.__class__ in COMPLEMENT_COMPARISONS:
        complement = COMPLEMENT_COMPARISONS[operand.__class__]
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, exp.And):
            return exp.or_(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if isinstance(inner, exp.Or):
            return exp.and_(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Remove parentheses around a nested connector of the same kind.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Fold constant, duplicate and comparable operands of AND / OR chains."""

    def _merge_pair(connector, lhs, rhs):
        # Returns the simplified form of `lhs <connector> rhs`, or a falsy
        # value when the pair cannot be merged.
        if lhs == rhs:
            return lhs

        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()
            if always_true(lhs) and always_true(rhs):
                return exp.true()
            if always_true(lhs):
                return rhs
            if always_true(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()
            if is_false(lhs) and is_false(rhs):
                return exp.false()
            if (
                (is_null(lhs) and is_null(rhs))
                or (is_null(lhs) and is_false(rhs))
                or (is_false(lhs) and is_null(rhs))
            ):
                return exp.null()
            if is_false(lhs):
                return rhs
            if is_false(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _merge_pair, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    has_complement = any(
        is_complement(a, b) for a, b in itertools.permutations(expression.flatten(), 2)
    )
    if has_complement:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    flattened = tuple(expression.flatten())
    # Operands with the same pseudo-sql are duplicates; the last one wins.
    deduped = {gen(operand): operand for operand in flattened}
    keys = tuple(deduped)

    # check if the operands are already sorted, if not sort them
    # A AND C AND B -> A AND B AND C
    out_of_order = any(current < previous for previous, current in zip(keys, keys[1:]))

    if out_of_order:
        expression = combine(*(operand for _, operand in sorted(deduped.items())), copy=False)
    elif len(deduped) < len(flattened):
        # we didn't have to sort but maybe we need to dedup
        expression = combine(*deduped.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector (in place).

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # The kind of connector the nested operands must have (the dual one).
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, inner_kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(exp.true() if inner_kind == exp.And else exp.false())
        elif is_complement(b, ab):
            ab.replace(exp.true() if inner_kind == exp.And else exp.false())
        elif (set(b.flatten()) if isinstance(b, inner_kind) else {b}) < set(a.flatten()):
            a.replace(exp.false() if inner_kind == exp.And else exp.true())
        elif isinstance(b, inner_kind):
            # eliminate
            b_operands = b.unnest_operands()
            ba, bb = b_operands

            if aa in b_operands and (is_complement(ab, ba) or is_complement(ab, bb)):
                a.replace(aa)
                b.replace(aa)
            elif ab in b_operands and (is_complement(aa, ba) or is_complement(aa, bb)):
                a.replace(ab)
                b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (id of the column node in the defining equality, literal)
    constant_mapping = {}
    for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
        if not isinstance(expr, exp.EQ):
            continue

        l, r = expr.left, expr.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
            constant_mapping[l] = (id(l), r)

    if not constant_mapping:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        column_id, constant = constant_mapping.get(column) or (None, None)
        is_null_check = isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)

        # Leave the defining occurrence itself and `col IS NULL` checks alone.
        if column_id is not None and id(column) != column_id and not is_null_check:
            column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
119        def wrapped(expression, *args, **kwargs):
120            try:
121                return func(expression, *args, **kwargs)
122            except exceptions:
123                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binary nodes, negations and date ops."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            text = operand.name
            # Negating an already negative literal strips its sign.
            negated = text[1:] if text[0] == "-" else f"-{text}"
            return exp.Literal.number(negated)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """Drop a pair of parentheses when the checks below deem it redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    # Parenthesized selects are always kept.
    if isinstance(this, exp.Select):
        return expression

    removable = (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    )
    return this if removable else expression
NONNULL_CONSTANTS = (<class 'sqlglot.expressions.Literal'>, <class 'sqlglot.expressions.Boolean'>)
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons of a COALESCE with a constant."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce, other = expression.left, expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce, other = expression.right, expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    first_constant = next(
        ((index, arg) for index, arg in enumerate(coalesce.expressions) if _is_constant(arg)),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    # Keep only the arguments that come before that constant.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    when_not_null = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(when_not_null, when_null, copy=False))
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    operands = expressions or expression.flatten()
    for is_string_group, group in itertools.groupby(operands, lambda e: e.is_string):
        if is_string_group:
            # Merge a run of adjacent string literals into a single literal.
            joined = sep.join(string.name for string in group)
            new_args.append(exp.Literal.string(joined))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr, *new_args]

    return concat_type(expressions=new_args, **args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    Args:
        expression: the node to inspect. Only `exp.Case` nodes and `exp.If` nodes
            that are not branches of a CASE are rewritten.

    Returns:
        The branch selected by an always-true condition, the default/false branch
        (or NULL when there is none) once every condition is always false, or
        `expression` itself, possibly with dead CASE branches removed.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` below detaches branches from the
        # "ifs" arg, and removing items from a list while looping over it would
        # make the loop skip the branch that follows a removed one.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.GTE'>}
def simplify_datetrunc(expression, *args, **kwargs):
119        def wrapped(expression, *args, **kwargs):
120            try:
121                return func(expression, *args, **kwargs)
122            except exceptions:
123                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Puts the operands of a comparison into a canonical order.

    Only node types listed in COMPLEMENT_COMPARISONS are considered. A lone
    column stays on the left and a lone constant stays on the right; otherwise
    the operands are ordered by their `gen` strings. When the operands are
    swapped, the operator is replaced through INVERSE_COMPARISONS (falling back
    to the same operator).
    """
    comparison = expression.__class__
    if comparison not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    # Already in canonical order
    if (left_is_column and not right_is_column) or (right_is_const and not left_is_const):
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not needs_swap:
        return expression

    swapped = INVERSE_COMPARISONS.get(comparison, comparison)
    return swapped(this=right, expression=left)
JOINS = {('', 'INNER'), ('RIGHT', 'OUTER'), ('RIGHT', ''), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Strips filters that are always true from `expression`, in place.

    WHERE clauses whose condition is always true are removed from their parent.
    Joins whose ON condition is always true, that have no USING or method, and
    whose (side, kind) pair is listed in JOINS are turned into CROSS joins.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        for arg_key, value in (("on", None), ("side", None), ("kind", "CROSS")):
            join.set(arg_key, value)
def always_true(expression):
def always_true(expression):
    """Returns whether `expression` is statically known to be true.

    A TRUE boolean is always true. A literal is treated as always true unless it
    is a numeric literal equal to zero: previously every literal qualified, so
    conditions such as `WHERE 0` were handled as if they were `WHERE TRUE`.

    Args:
        expression: the node to check; may be None or any other object.

    Returns:
        True if the expression is known to be true, False otherwise.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if not isinstance(expression, exp.Literal):
        return False

    if not expression.is_number:
        return True

    try:
        return Decimal(expression.name) != 0
    except ArithmeticError:
        # decimal.InvalidOperation: the numeric text can't be parsed here, so keep
        # the previous behaviour of treating the literal as true
        return True
def always_false(expression):
def always_false(expression):
    """Returns whether `expression` is a FALSE boolean or a NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Returns whether `b` is a NOT node whose operand equals `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Returns whether `a` is exactly an exp.Boolean node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Returns whether `a` is exactly an exp.Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluates the comparison `expression` on the Python values `a` and `b`.

    Returns:
        A boolean literal holding the result, or None when `expression` is not
        one of the supported comparison types (EQ, IS, NEQ, GT, GTE, LT, LTE).
    """
    # Checked in order; the comparison itself only runs for the matching entry
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for comparison_types, compare in comparators:
        if isinstance(expression, comparison_types):
            return boolean_literal(compare(a, b))

    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Converts `value` to a date, if possible.

    Args:
        value: a datetime, a date, or an ISO-8601 formatted string.

    Returns:
        The corresponding date, or None when `value` can't be interpreted as one.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string at all
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
 997def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 998    if isinstance(value, datetime.datetime):
 999        return value
1000    if isinstance(value, datetime.date):
1001        return datetime.datetime(year=value.year, month=value.month, day=value.day)
1002    try:
1003        return datetime.datetime.fromisoformat(value)
1004    except ValueError:
1005        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Converts `value` to the Python object matching the temporal type `to`.

    Args:
        value: the value to convert; it is handed to `cast_as_date` or
            `cast_as_datetime`.
        to: the target data type.

    Returns:
        A date when `to` is DATE, a datetime when `to` is any other temporal
        type, and None when `value` is falsy, `to` is not a temporal type, or
        the conversion fails.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluates a (possibly nested) cast of a literal to a date or datetime.

    Args:
        cast: an `exp.Cast`, or an `exp.TsOrDsToDate` without a format, wrapping
            either a literal or another such cast.

    Returns:
        The date / datetime the cast evaluates to, or None when `cast` has any
        other shape or the value can't be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are resolved inside-out
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """Builds the time delta described by an interval expression.

    Returns:
        The result of `interval(unit, n)`, or None when the amount is not an
        integer, the unit is unsupported, or dateutil can't be imported.
    """
    try:
        amount = int(expression.name)
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def date_literal(date):
def date_literal(date):
    """Wraps `date` in a string literal cast to DATETIME (for datetimes) or DATE."""
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Returns a relativedelta spanning `n` times the given `unit`.

    Args:
        unit: one of year, quarter, month, week, day, hour, minute, second.
        n: how many units the delta should cover.

    Raises:
        UnsupportedUnit: if `unit` is not one of the names above.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, amount of that keyword per unit)
    supported = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, per_unit in supported:
        if unit == name:
            return relativedelta(**{keyword: per_unit * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncates `d` to the start of the period of the given `unit`.

    Args:
        d: the date to truncate.
        unit: one of year, quarter, month, week, day.
        dialect: provides WEEK_OFFSET, which is subtracted from the Monday-based
            weekday to find the week start (presumably -1 means weeks start on
            Sunday; confirm against the Dialect documentation).

    Returns:
        The first day of the period that contains `d`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the names above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # date.weekday() is Monday (0) through Sunday (6). The distance to the
        # week start has to wrap around modulo 7, otherwise a non-zero offset can
        # land a whole week too early (or even after `d`).
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Rounds `d` up to the next `unit` boundary.

    A date that already sits on a boundary (i.e. equals its own floor) is
    returned unchanged; otherwise the result is the floor plus one `unit`.
    """
    start_of_period = date_floor(d, unit, dialect)
    return d if start_of_period == d else start_of_period + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Returns the TRUE expression for a truthy `condition`, otherwise FALSE."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Minimal pseudo SQL generator that produces sortable, unique strings.

    Optimization needs to sort and dedupe SQL, and running the real generator
    for that is expensive, so only a bare minimum is implemented here: None
    becomes "_", iterables are comma-joined, non-expression values are
    stringified, and expressions go through GEN_MAP when their type is
    registered there or fall back to their key followed by their args.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    handler = GEN_MAP.get(type(expression))
    if handler is not None:
        return handler(expression)

    rendered_args = gen(expression.args.values())
    return f"{expression.key} {rendered_args}"

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

GEN_MAP = {<class 'sqlglot.expressions.Add'>: <function <lambda>>, <class 'sqlglot.expressions.And'>: <function <lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function <lambda>>, <class 'sqlglot.expressions.Between'>: <function <lambda>>, <class 'sqlglot.expressions.Boolean'>: <function <lambda>>, <class 'sqlglot.expressions.Bracket'>: <function <lambda>>, <class 'sqlglot.expressions.Column'>: <function <lambda>>, <class 'sqlglot.expressions.DataType'>: <function <lambda>>, <class 'sqlglot.expressions.Div'>: <function <lambda>>, <class 'sqlglot.expressions.Dot'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.Identifier'>: <function <lambda>>, <class 'sqlglot.expressions.ILike'>: <function <lambda>>, <class 'sqlglot.expressions.In'>: <function <lambda>>, <class 'sqlglot.expressions.Is'>: <function <lambda>>, <class 'sqlglot.expressions.Like'>: <function <lambda>>, <class 'sqlglot.expressions.Literal'>: <function <lambda>>, <class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.Mod'>: <function <lambda>>, <class 'sqlglot.expressions.Mul'>: <function <lambda>>, <class 'sqlglot.expressions.Neg'>: <function <lambda>>, <class 'sqlglot.expressions.NEQ'>: <function <lambda>>, <class 'sqlglot.expressions.Not'>: <function <lambda>>, <class 'sqlglot.expressions.Null'>: <function <lambda>>, <class 'sqlglot.expressions.Or'>: <function <lambda>>, <class 'sqlglot.expressions.Paren'>: <function <lambda>>, <class 'sqlglot.expressions.Sub'>: <function <lambda>>, <class 'sqlglot.expressions.Subquery'>: <function <lambda>>, <class 'sqlglot.expressions.Table'>: <function <lambda>>, <class 'sqlglot.expressions.Var'>: <function <lambda>>}