Edit on GitHub

sqlglot.optimizer.simplify

   1import datetime
   2import functools
   3import itertools
   4import typing as t
   5from collections import deque
   6from decimal import Decimal
   7
   8import sqlglot
   9from sqlglot import exp
  10from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
  11from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  12
# Final means that an expression should not be simplified.
# `simplify` stores this key in a node's `meta` dict and returns any node that
# carries it untouched.
FINAL = "final"
  15
  16
  17class UnsupportedUnit(Exception):
  18    pass
  19
  20
  21def simplify(expression, constant_propagation=False):
  22    """
  23    Rewrite sqlglot AST to simplify expressions.
  24
  25    Example:
  26        >>> import sqlglot
  27        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  28        >>> simplify(expression).sql()
  29        'TRUE'
  30
  31    Args:
  32        expression (sqlglot.Expression): expression to simplify
  33        constant_propagation: whether or not the constant propagation rule should be used
  34
  35    Returns:
  36        sqlglot.Expression: simplified expression
  37    """
  38
  39    # group by expressions cannot be simplified, for example
  40    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
  41    # the projection must exactly match the group by key
  42    for group in expression.find_all(exp.Group):
  43        select = group.parent
  44        groups = set(group.expressions)
  45        group.meta[FINAL] = True
  46
  47        for e in select.selects:
  48            for node, *_ in e.walk():
  49                if node in groups:
  50                    e.meta[FINAL] = True
  51                    break
  52
  53        having = select.args.get("having")
  54        if having:
  55            for node, *_ in having.walk():
  56                if node in groups:
  57                    having.meta[FINAL] = True
  58                    break
  59
  60    def _simplify(expression, root=True):
  61        if expression.meta.get(FINAL):
  62            return expression
  63
  64        # Pre-order transformations
  65        node = expression
  66        node = rewrite_between(node)
  67        node = uniq_sort(node, root)
  68        node = absorb_and_eliminate(node, root)
  69        node = simplify_concat(node)
  70        node = simplify_conditionals(node)
  71
  72        if constant_propagation:
  73            node = propagate_constants(node, root)
  74
  75        exp.replace_children(node, lambda e: _simplify(e, False))
  76
  77        # Post-order transformations
  78        node = simplify_not(node)
  79        node = flatten(node)
  80        node = simplify_connectors(node, root)
  81        node = remove_complements(node, root)
  82        node = simplify_coalesce(node)
  83        node.parent = expression.parent
  84        node = simplify_literals(node, root)
  85        node = simplify_equality(node)
  86        node = simplify_parens(node)
  87        node = simplify_datetrunc_predicate(node)
  88
  89        if root:
  90            expression.replace(node)
  91
  92        return node
  93
  94    expression = while_changing(expression, _simplify)
  95    remove_where_true(expression)
  96    return expression
  97
  98
  99def catch(*exceptions):
 100    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 101
 102    def decorator(func):
 103        def wrapped(expression, *args, **kwargs):
 104            try:
 105                return func(expression, *args, **kwargs)
 106            except exceptions:
 107                return expression
 108
 109        return wrapped
 110
 111    return decorator
 112
 113
 114def rewrite_between(expression: exp.Expression) -> exp.Expression:
 115    """Rewrite x between y and z to x >= y AND x <= z.
 116
 117    This is done because comparison simplification is only done on lt/lte/gt/gte.
 118    """
 119    if isinstance(expression, exp.Between):
 120        return exp.and_(
 121            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 122            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 123            copy=False,
 124        )
 125    return expression
 126
 127
 128def simplify_not(expression):
 129    """
 130    Demorgan's Law
 131    NOT (x OR y) -> NOT x AND NOT y
 132    NOT (x AND y) -> NOT x OR NOT y
 133    """
 134    if isinstance(expression, exp.Not):
 135        if is_null(expression.this):
 136            return exp.null()
 137        if isinstance(expression.this, exp.Paren):
 138            condition = expression.this.unnest()
 139            if isinstance(condition, exp.And):
 140                return exp.or_(
 141                    exp.not_(condition.left, copy=False),
 142                    exp.not_(condition.right, copy=False),
 143                    copy=False,
 144                )
 145            if isinstance(condition, exp.Or):
 146                return exp.and_(
 147                    exp.not_(condition.left, copy=False),
 148                    exp.not_(condition.right, copy=False),
 149                    copy=False,
 150                )
 151            if is_null(condition):
 152                return exp.null()
 153        if always_true(expression.this):
 154            return exp.false()
 155        if is_false(expression.this):
 156            return exp.true()
 157        if isinstance(expression.this, exp.Not):
 158            # double negation
 159            # NOT NOT x -> x
 160            return expression.this.this
 161    return expression
 162
 163
 164def flatten(expression):
 165    """
 166    A AND (B AND C) -> A AND B AND C
 167    A OR (B OR C) -> A OR B OR C
 168    """
 169    if isinstance(expression, exp.Connector):
 170        for node in expression.args.values():
 171            child = node.unnest()
 172            if isinstance(child, expression.__class__):
 173                node.replace(child)
 174    return expression
 175
 176
 177def simplify_connectors(expression, root=True):
 178    def _simplify_connectors(expression, left, right):
 179        if left == right:
 180            return left
 181        if isinstance(expression, exp.And):
 182            if is_false(left) or is_false(right):
 183                return exp.false()
 184            if is_null(left) or is_null(right):
 185                return exp.null()
 186            if always_true(left) and always_true(right):
 187                return exp.true()
 188            if always_true(left):
 189                return right
 190            if always_true(right):
 191                return left
 192            return _simplify_comparison(expression, left, right)
 193        elif isinstance(expression, exp.Or):
 194            if always_true(left) or always_true(right):
 195                return exp.true()
 196            if is_false(left) and is_false(right):
 197                return exp.false()
 198            if (
 199                (is_null(left) and is_null(right))
 200                or (is_null(left) and is_false(right))
 201                or (is_false(left) and is_null(right))
 202            ):
 203                return exp.null()
 204            if is_false(left):
 205                return right
 206            if is_false(right):
 207                return left
 208            return _simplify_comparison(expression, left, right, or_=True)
 209
 210    if isinstance(expression, exp.Connector):
 211        return _flat_simplify(expression, _simplify_connectors, root)
 212    return expression
 213
 214
 215LT_LTE = (exp.LT, exp.LTE)
 216GT_GTE = (exp.GT, exp.GTE)
 217
 218COMPARISONS = (
 219    *LT_LTE,
 220    *GT_GTE,
 221    exp.EQ,
 222    exp.NEQ,
 223    exp.Is,
 224)
 225
 226INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 227    exp.LT: exp.GT,
 228    exp.GT: exp.LT,
 229    exp.LTE: exp.GTE,
 230    exp.GTE: exp.LTE,
 231}
 232
 233
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons of the same column against literals into one expression.

    Args:
        expression: the connector being simplified; handed back untouched when one of the
            operands has nothing left once the shared column is removed.
        left: first comparison operand of the connector.
        right: second comparison operand of the connector.
        or_: True when the operands are joined by OR, False when joined by AND.

    Returns:
        The expression that replaces the pair, or None when nothing can be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Only comparisons that share a column operand can be merged
        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                # l and r are the operands that are not the shared column
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            # Only number vs number and string vs string can be compared statically
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # Try both orderings of the pair so every rule below is checked in both directions
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two bounds in the same direction: AND keeps the tighter one, OR the looser one
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Ranges that cannot overlap
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality either contradicts the other comparison or makes it redundant
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 293
 294
 295def remove_complements(expression, root=True):
 296    """
 297    Removing complements.
 298
 299    A AND NOT A -> FALSE
 300    A OR NOT A -> TRUE
 301    """
 302    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 303        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 304
 305        for a, b in itertools.permutations(expression.flatten(), 2):
 306            if is_complement(a, b):
 307                return complement
 308    return expression
 309
 310
 311def uniq_sort(expression, root=True):
 312    """
 313    Uniq and sort a connector.
 314
 315    C AND A AND B AND B -> A AND B AND C
 316    """
 317    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 318        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 319        flattened = tuple(expression.flatten())
 320        deduped = {gen(e): e for e in flattened}
 321        arr = tuple(deduped.items())
 322
 323        # check if the operands are already sorted, if not sort them
 324        # A AND C AND B -> A AND B AND C
 325        for i, (sql, e) in enumerate(arr[1:]):
 326            if sql < arr[i][0]:
 327                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 328                break
 329        else:
 330            # we didn't have to sort but maybe we need to dedup
 331            if len(deduped) < len(flattened):
 332                expression = result_func(*deduped.values(), copy=False)
 333
 334    return expression
 335
 336
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place with `replace`; the `expression` passed in is
    what gets returned. Constants introduced here are left for other rules to fold.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the one nested operands must have for these laws
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b, so it becomes the neutral element of `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's, so a is redundant next to b
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 375
 376
 377def propagate_constants(expression, root=True):
 378    """
 379    Propagate constants for conjunctions in DNF:
 380
 381    SELECT * FROM t WHERE a = b AND b = 5 becomes
 382    SELECT * FROM t WHERE a = 5 AND b = 5
 383
 384    Reference: https://www.sqlite.org/optoverview.html
 385    """
 386
 387    if (
 388        isinstance(expression, exp.And)
 389        and (root or not expression.same_parent)
 390        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 391    ):
 392        constant_mapping = {}
 393        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
 394            if isinstance(expr, exp.EQ):
 395                l, r = expr.left, expr.right
 396
 397                # TODO: create a helper that can be used to detect nested literal expressions such
 398                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 399                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 400                    pass
 401                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
 402                    l, r = r, l
 403                else:
 404                    continue
 405
 406                constant_mapping[l] = (id(l), r)
 407
 408        if constant_mapping:
 409            for column in find_all_in_scope(expression, exp.Column):
 410                parent = column.parent
 411                column_id, constant = constant_mapping.get(column) or (None, None)
 412                if (
 413                    column_id is not None
 414                    and id(column) != column_id
 415                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 416                ):
 417                    column.replace(constant.copy())
 418
 419    return expression
 420
 421
# Arithmetic operation that undoes each date operation when it is moved to the
# other side of a comparison (used by simplify_equality).
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Same mapping, extended with plain addition and subtraction.
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 434
 435
def _is_number(expression: exp.Expression) -> bool:
    """Predicate form of `expression.is_number`, used as an operand check by simplify_equality."""
    return expression.is_number
 438
 439
 440def _is_interval(expression: exp.Expression) -> bool:
 441    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 442
 443
 444@catch(ModuleNotFoundError, UnsupportedUnit)
 445def simplify_equality(expression: exp.Expression) -> exp.Expression:
 446    """
 447    Use the subtraction and addition properties of equality to simplify expressions:
 448
 449        x + 1 = 3 becomes x = 2
 450
 451    There are two binary operations in the above expression: + and =
 452    Here's how we reference all the operands in the code below:
 453
 454          l     r
 455        x + 1 = 3
 456        a   b
 457    """
 458    if isinstance(expression, COMPARISONS):
 459        l, r = expression.left, expression.right
 460
 461        if l.__class__ in INVERSE_OPS:
 462            pass
 463        elif r.__class__ in INVERSE_OPS:
 464            l, r = r, l
 465        else:
 466            return expression
 467
 468        if r.is_number:
 469            a_predicate = _is_number
 470            b_predicate = _is_number
 471        elif _is_date_literal(r):
 472            a_predicate = _is_date_literal
 473            b_predicate = _is_interval
 474        else:
 475            return expression
 476
 477        if l.__class__ in INVERSE_DATE_OPS:
 478            l = t.cast(exp.IntervalOp, l)
 479            a = l.this
 480            b = l.interval()
 481        else:
 482            l = t.cast(exp.Binary, l)
 483            a, b = l.left, l.right
 484
 485        if not a_predicate(a) and b_predicate(b):
 486            pass
 487        elif not a_predicate(b) and b_predicate(a):
 488            a, b = b, a
 489        else:
 490            return expression
 491
 492        return expression.__class__(
 493            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 494        )
 495    return expression
 496
 497
 498def simplify_literals(expression, root=True):
 499    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 500        return _flat_simplify(expression, _simplify_binary, root)
 501
 502    if isinstance(expression, exp.Neg):
 503        this = expression.this
 504        if this.is_number:
 505            value = this.name
 506            if value[0] == "-":
 507                return exp.Literal.number(value[1:])
 508            return exp.Literal.number(f"-{value}")
 509
 510    if type(expression) in INVERSE_DATE_OPS:
 511        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 512
 513    return expression
 514
 515
 516def _simplify_binary(expression, a, b):
 517    if isinstance(expression, exp.Is):
 518        if isinstance(b, exp.Not):
 519            c = b.this
 520            not_ = True
 521        else:
 522            c = b
 523            not_ = False
 524
 525        if is_null(c):
 526            if isinstance(a, exp.Literal):
 527                return exp.true() if not_ else exp.false()
 528            if is_null(a):
 529                return exp.false() if not_ else exp.true()
 530    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
 531        return None
 532    elif is_null(a) or is_null(b):
 533        return exp.null()
 534
 535    if a.is_number and b.is_number:
 536        a = int(a.name) if a.is_int else Decimal(a.name)
 537        b = int(b.name) if b.is_int else Decimal(b.name)
 538
 539        if isinstance(expression, exp.Add):
 540            return exp.Literal.number(a + b)
 541        if isinstance(expression, exp.Sub):
 542            return exp.Literal.number(a - b)
 543        if isinstance(expression, exp.Mul):
 544            return exp.Literal.number(a * b)
 545        if isinstance(expression, exp.Div):
 546            # engines have differing int div behavior so intdiv is not safe
 547            if isinstance(a, int) and isinstance(b, int):
 548                return None
 549            return exp.Literal.number(a / b)
 550
 551        boolean = eval_boolean(expression, a, b)
 552
 553        if boolean:
 554            return boolean
 555    elif a.is_string and b.is_string:
 556        boolean = eval_boolean(expression, a.this, b.this)
 557
 558        if boolean:
 559            return boolean
 560    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 561        a, b = extract_date(a), extract_interval(b)
 562        if a and b:
 563            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 564                return date_literal(a + b)
 565            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 566                return date_literal(a - b)
 567    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 568        a, b = extract_interval(a), extract_date(b)
 569        # you cannot subtract a date from an interval
 570        if a and b and isinstance(expression, exp.Add):
 571            return date_literal(a + b)
 572    elif _is_date_literal(a) and _is_date_literal(b):
 573        if isinstance(expression, exp.Predicate):
 574            a, b = extract_date(a), extract_date(b)
 575            boolean = eval_boolean(expression, a, b)
 576            if boolean:
 577                return boolean
 578
 579    return None
 580
 581
 582def simplify_parens(expression):
 583    if not isinstance(expression, exp.Paren):
 584        return expression
 585
 586    this = expression.this
 587    parent = expression.parent
 588
 589    if not isinstance(this, exp.Select) and (
 590        not isinstance(parent, (exp.Condition, exp.Binary))
 591        or isinstance(parent, exp.Paren)
 592        or not isinstance(this, exp.Binary)
 593        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
 594        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 595        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 596        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 597    ):
 598        return this
 599    return expression
 600
 601
 602NONNULL_CONSTANTS = (
 603    exp.Literal,
 604    exp.Boolean,
 605)
 606
 607CONSTANTS = (
 608    exp.Literal,
 609    exp.Boolean,
 610    exp.Null,
 611)
 612
 613
 614def _is_nonnull_constant(expression: exp.Expression) -> bool:
 615    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
 616
 617
 618def _is_constant(expression: exp.Expression) -> bool:
 619    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
 620
 621
 622def simplify_coalesce(expression):
 623    # COALESCE(x) -> x
 624    if (
 625        isinstance(expression, exp.Coalesce)
 626        and (not expression.expressions or _is_nonnull_constant(expression.this))
 627        # COALESCE is also used as a Spark partitioning hint
 628        and not isinstance(expression.parent, exp.Hint)
 629    ):
 630        return expression.this
 631
 632    if not isinstance(expression, COMPARISONS):
 633        return expression
 634
 635    if isinstance(expression.left, exp.Coalesce):
 636        coalesce = expression.left
 637        other = expression.right
 638    elif isinstance(expression.right, exp.Coalesce):
 639        coalesce = expression.right
 640        other = expression.left
 641    else:
 642        return expression
 643
 644    # This transformation is valid for non-constants,
 645    # but it really only does anything if they are both constants.
 646    if not _is_constant(other):
 647        return expression
 648
 649    # Find the first constant arg
 650    for arg_index, arg in enumerate(coalesce.expressions):
 651        if _is_constant(other):
 652            break
 653    else:
 654        return expression
 655
 656    coalesce.set("expressions", coalesce.expressions[:arg_index])
 657
 658    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 659    # since we already remove COALESCE at the top of this function.
 660    coalesce = coalesce if coalesce.expressions else coalesce.this
 661
 662    # This expression is more complex than when we started, but it will get simplified further
 663    return exp.paren(
 664        exp.or_(
 665            exp.and_(
 666                coalesce.is_(exp.null()).not_(copy=False),
 667                expression.copy(),
 668                copy=False,
 669            ),
 670            exp.and_(
 671                coalesce.is_(exp.null()),
 672                type(expression)(this=arg.copy(), expression=other.copy()),
 673                copy=False,
 674            ),
 675            copy=False,
 676        )
 677    )
 678
 679
# Expression types that simplify_concat treats as string concatenation
CONCATS = (exp.Concat, exp.DPipe)
 681
 682
 683def simplify_concat(expression):
 684    """Reduces all groups that contain string literals by concatenating them."""
 685    if not isinstance(expression, CONCATS) or (
 686        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 687        isinstance(expression, exp.ConcatWs)
 688        and not expression.expressions[0].is_string
 689    ):
 690        return expression
 691
 692    if isinstance(expression, exp.ConcatWs):
 693        sep_expr, *expressions = expression.expressions
 694        sep = sep_expr.name
 695        concat_type = exp.ConcatWs
 696        args = {}
 697    else:
 698        expressions = expression.expressions
 699        sep = ""
 700        concat_type = exp.Concat
 701        args = {"safe": expression.args.get("safe")}
 702
 703    new_args = []
 704    for is_string_group, group in itertools.groupby(
 705        expressions or expression.flatten(), lambda e: e.is_string
 706    ):
 707        if is_string_group:
 708            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 709        else:
 710            new_args.extend(group)
 711
 712    if len(new_args) == 1 and new_args[0].is_string:
 713        return new_args[0]
 714
 715    if concat_type is exp.ConcatWs:
 716        new_args = [sep_expr] + new_args
 717
 718    return concat_type(expressions=new_args, **args)
 719
 720
def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    Returns:
        The branch selected by a statically known condition, NULL when every branch was
        removed and there is no default, or `expression` itself (CASE branches whose
        condition is always false are popped from it in place).
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                # NOTE(review): `this.pop()` runs for every branch, not only the first — confirm
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # This branch can never be taken
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        # Standalone IF; the IF nodes that make up CASE branches are handled above
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
 745
 746
# Pair of dates; _datetrunc_range documents it as the half-open range [min, max)
DateRange = t.Tuple[datetime.date, datetime.date]
 748
 749
 750def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
 751    """
 752    Get the date range for a DATE_TRUNC equality comparison:
 753
 754    Example:
 755        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 756    Returns:
 757        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 758    """
 759    floor = date_floor(date, unit)
 760
 761    if date != floor:
 762        # This will always be False, except for NULL values.
 763        return None
 764
 765    return floor, floor + interval(unit)
 766
 767
 768def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 769    """Get the logical expression for a date range"""
 770    return exp.and_(
 771        left >= date_literal(drange[0]),
 772        left < date_literal(drange[1]),
 773        copy=False,
 774    )
 775
 776
 777def _datetrunc_eq(
 778    left: exp.Expression, date: datetime.date, unit: str
 779) -> t.Optional[exp.Expression]:
 780    drange = _datetrunc_range(date, unit)
 781    if not drange:
 782        return None
 783
 784    return _datetrunc_eq_expression(left, drange)
 785
 786
 787def _datetrunc_neq(
 788    left: exp.Expression, date: datetime.date, unit: str
 789) -> t.Optional[exp.Expression]:
 790    drange = _datetrunc_range(date, unit)
 791    if not drange:
 792        return None
 793
 794    return exp.and_(
 795        left < date_literal(drange[0]),
 796        left >= date_literal(drange[1]),
 797        copy=False,
 798    )
 799
 800
 801DateTruncBinaryTransform = t.Callable[
 802    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
 803]
 804DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 805    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
 806    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
 807    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
 808    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
 809    exp.EQ: _datetrunc_eq,
 810    exp.NEQ: _datetrunc_neq,
 811}
 812DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 813
 814
 815def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 816    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
 817
 818
 819@catch(ModuleNotFoundError, UnsupportedUnit)
 820def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
 821    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 822    comparison = expression.__class__
 823
 824    if comparison not in DATETRUNC_COMPARISONS:
 825        return expression
 826
 827    if isinstance(expression, exp.Binary):
 828        l, r = expression.left, expression.right
 829
 830        if _is_datetrunc_predicate(l, r):
 831            pass
 832        elif _is_datetrunc_predicate(r, l):
 833            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
 834            l, r = r, l
 835        else:
 836            return expression
 837
 838        l = t.cast(exp.DateTrunc, l)
 839        unit = l.unit.name.lower()
 840        date = extract_date(r)
 841
 842        if not date:
 843            return expression
 844
 845        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
 846    elif isinstance(expression, exp.In):
 847        l = expression.this
 848        rs = expression.expressions
 849
 850        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 851            l = t.cast(exp.DateTrunc, l)
 852            unit = l.unit.name.lower()
 853
 854            ranges = []
 855            for r in rs:
 856                date = extract_date(r)
 857                if not date:
 858                    return expression
 859                drange = _datetrunc_range(date, unit)
 860                if drange:
 861                    ranges.append(drange)
 862
 863            if not ranges:
 864                return expression
 865
 866            ranges = merge_ranges(ranges)
 867
 868            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
 869
 870    return expression
 871
 872
# CROSS joins result in an empty table if the right table is empty,
# so only certain types of joins can be simplified to CROSS.
# In other words, LEFT JOIN x ON TRUE != CROSS JOIN x.
# Entries are (side, kind) pairs, compared against (join.side, join.kind).
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
 882
 883
 884def remove_where_true(expression):
 885    for where in expression.find_all(exp.Where):
 886        if always_true(where.this):
 887            where.parent.set("where", None)
 888    for join in expression.find_all(exp.Join):
 889        if (
 890            always_true(join.args.get("on"))
 891            and not join.args.get("using")
 892            and not join.args.get("method")
 893            and (join.side, join.kind) in JOINS
 894        ):
 895            join.set("on", None)
 896            join.set("side", None)
 897            join.set("kind", "CROSS")
 898
 899
 900def always_true(expression):
 901    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 902        expression, exp.Literal
 903    )
 904
 905
 906def always_false(expression):
 907    return is_false(expression) or is_null(expression)
 908
 909
 910def is_complement(a, b):
 911    return isinstance(b, exp.Not) and b.this == a
 912
 913
 914def is_false(a: exp.Expression) -> bool:
 915    return type(a) is exp.Boolean and not a.this
 916
 917
 918def is_null(a: exp.Expression) -> bool:
 919    return type(a) is exp.Null
 920
 921
 922def eval_boolean(expression, a, b):
 923    if isinstance(expression, (exp.EQ, exp.Is)):
 924        return boolean_literal(a == b)
 925    if isinstance(expression, exp.NEQ):
 926        return boolean_literal(a != b)
 927    if isinstance(expression, exp.GT):
 928        return boolean_literal(a > b)
 929    if isinstance(expression, exp.GTE):
 930        return boolean_literal(a >= b)
 931    if isinstance(expression, exp.LT):
 932        return boolean_literal(a < b)
 933    if isinstance(expression, exp.LTE):
 934        return boolean_literal(a <= b)
 935    return None
 936
 937
 938def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 939    if isinstance(value, datetime.datetime):
 940        return value.date()
 941    if isinstance(value, datetime.date):
 942        return value
 943    try:
 944        return datetime.datetime.fromisoformat(value).date()
 945    except ValueError:
 946        return None
 947
 948
 949def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 950    if isinstance(value, datetime.datetime):
 951        return value
 952    if isinstance(value, datetime.date):
 953        return datetime.datetime(year=value.year, month=value.month, day=value.day)
 954    try:
 955        return datetime.datetime.fromisoformat(value)
 956    except ValueError:
 957        return None
 958
 959
 960def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 961    if not value:
 962        return None
 963    if to.is_type(exp.DataType.Type.DATE):
 964        return cast_as_date(value)
 965    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 966        return cast_as_datetime(value)
 967    return None
 968
 969
 970def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 971    if isinstance(cast, exp.Cast):
 972        to = cast.to
 973    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
 974        to = exp.DataType.build(exp.DataType.Type.DATE)
 975    else:
 976        return None
 977
 978    if isinstance(cast.this, exp.Literal):
 979        value: t.Any = cast.this.name
 980    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 981        value = extract_date(cast.this)
 982    else:
 983        return None
 984    return cast_value(value, to)
 985
 986
 987def _is_date_literal(expression: exp.Expression) -> bool:
 988    return extract_date(expression) is not None
 989
 990
 991def extract_interval(expression):
 992    try:
 993        n = int(expression.name)
 994        unit = expression.text("unit").lower()
 995        return interval(unit, n)
 996    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
 997        return None
 998
 999
def date_literal(date):
    """Wrap `date` in a string literal cast to DATETIME (for datetimes) or DATE."""
    literal = exp.Literal.string(date)
    if isinstance(date, datetime.datetime):
        target = exp.DataType.Type.DATETIME
    else:
        target = exp.DataType.Type.DATE
    return exp.cast(literal, target)
1007
1008
def interval(unit: str, n: int = 1):
    """Return a dateutil `relativedelta` spanning `n` units of `unit`.

    Raises:
        ModuleNotFoundError: if dateutil is not installed.
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spans[unit]
    return relativedelta(**{keyword: factor * n})
1030
1031
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: a date or datetime. For a datetime the time of day is reset to
            midnight as well, so the result is a true unit boundary; before,
            the time component was carried over into the "floored" value.
        unit: one of "year", "quarter", "month", "week" or "day".

    Returns:
        A value of the same type as `d`.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if isinstance(d, datetime.datetime):
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1053
1054
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary; a value already on a boundary is kept."""
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
1062
1063
def boolean_literal(condition):
    """Return the TRUE expression if `condition` is truthy, else the FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
1066
1067
def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the flattened operands of `expression`.

    `simplifier(expression, a, b)` is called on pairs of operands. A truthy result
    other than `expression` itself replaces that pair. If at least one pair was
    merged, the surviving operands are folded into a left-nested chain of the same
    expression class; otherwise `expression` is returned unchanged.
    """
    # NOTE(review): `same_parent` presumably means the parent has the same type, so
    # only the top of a chain is processed rather than each nested node; confirm.
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # `a` and `b` merged: drop `b` and retry with the merged result
                    # first. The break makes mutating the deque mid-iteration safe.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # `a` could not be combined with any remaining operand.
                operands.append(a)

        # Rebuild only if something was actually merged.
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
1092
1093
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    # Use the dedicated formatter when one is registered for this exact type,
    # otherwise fall back to "<key> <args>".
    handler = GEN_MAP.get(type(expression))
    if handler is not None:
        return handler(expression)
    return f"{expression.key} {gen(expression.args.values())}"
1111
1112
# Per-type formatters used by `gen`, keyed by exact expression type. Types not
# listed here fall back to gen's generic "<key> <args>" form.
GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Div: lambda e: _binary(e, "/"),
    exp.Dot: lambda e: _binary(e, "."),
    exp.EQ: lambda e: _binary(e, "="),
    exp.GT: lambda e: _binary(e, ">"),
    exp.GTE: lambda e: _binary(e, ">="),
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.ILike: lambda e: _binary(e, "ILIKE"),
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Is: lambda e: _binary(e, "IS"),
    exp.Like: lambda e: _binary(e, "LIKE"),
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.LT: lambda e: _binary(e, "<"),
    exp.LTE: lambda e: _binary(e, "<="),
    exp.Mod: lambda e: _binary(e, "%"),
    exp.Mul: lambda e: _binary(e, "*"),
    exp.Neg: lambda e: _unary(e, "-"),
    exp.NEQ: lambda e: _binary(e, "<>"),
    exp.Not: lambda e: _unary(e, "NOT"),
    exp.Null: lambda e: "NULL",
    exp.Or: lambda e: _binary(e, "OR"),
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Sub: lambda e: _binary(e, "-"),
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}
1148
1149
def _binary(e: exp.Binary, op: str) -> str:
    """Format a binary node as "<left> <op> <right>" using `gen` for both operands."""
    lhs = gen(e.left)
    rhs = gen(e.right)
    return " ".join((lhs, op, rhs))
1152
1153
def _unary(e: exp.Unary, op: str) -> str:
    """Format a unary node as "<op> <operand>" using `gen` for the operand."""
    operand = gen(e.this)
    return " ".join((op, operand))
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised by the date helpers when given a unit they do not handle."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression, constant_propagation=False):
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains one of the group by keys.
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same for the HAVING clause, if there is one.
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes marked FINAL above are returned untouched.
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Point the (possibly new) node at the original parent before running
        # the remaining passes.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether or not the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    When the decorated function raises one of `exceptions`, the wrapper returns its
    first argument (the expression) unchanged instead of propagating the error.
    The wrapper keeps the decorated function's name, docstring and signature
    metadata via `functools.wraps`.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify NOT expressions, including via De Morgan's laws:
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this
    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        # De Morgan: negate both sides and flip the connector.
        if isinstance(inner, exp.And):
            flipped = exp.or_
        elif isinstance(inner, exp.Or):
            flipped = exp.and_
        else:
            flipped = None

        if flipped is not None:
            return flipped(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Remove redundant parentheses around nested connectors of the same type:
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Fold AND/OR chains whose operands are TRUE, FALSE, NULL or duplicates."""

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true, right_true = always_true(left), always_true(right)
            if left_true:
                return exp.true() if right_true else right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false, right_false = is_false(left), is_false(right)
            left_null, right_null = is_null(left), is_null(right)

            if left_false and right_false:
                return exp.false()
            if (left_null and (right_null or right_false)) or (left_false and right_null):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    rebuild = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keyed by pseudo-sql, so duplicates collapse onto one entry.
    by_sql = {gen(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # A AND C AND B -> A AND B AND C
    out_of_order = any(current < previous for previous, current in zip(keys, keys[1:]))

    if out_of_order:
        expression = rebuild(*(e for _, e in sorted(by_sql.items())), copy=False)
    elif len(by_sql) < len(operands):
        # we didn't have to sort but we still need to dedup
        expression = rebuild(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place and `expression` itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands to inspect.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for the neutral element of `kind`.
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's, so `a` is redundant:
                    # swap it for the neutral element of the outer connector.
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # One operand is shared and the other two are complements:
                    # both sides reduce to the shared operand.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    The columns are replaced in place and `expression` itself is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node in the defining equality, literal).
        constant_mapping = {}
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the column occurrence that defines the constant (same id),
                # and leave `col IS NULL` checks alone.
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binaries, negated numbers and date operations."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating an already negative number just drops its sign.
            negated = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(negated)

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """Unwrap a parenthesized expression when the checks below deem the parentheses redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # Parenthesized selects are always kept as they are.
    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
NONNULL_CONSTANTS = (<class 'sqlglot.expressions.Literal'>, <class 'sqlglot.expressions.Boolean'>)
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons of a COALESCE against a constant.

    - `COALESCE(x)`, or a COALESCE whose first argument is a non-null constant,
      becomes that first argument (unless it is used as a hint).
    - `COALESCE(x, ..., c) <cmp> k`, with `c` the first constant argument and `k`
      a constant, is split into a "not null" branch and an "is null" branch that
      compares `c` with `k` directly, so later passes can fold it.

    Returns the rewritten expression, or `expression` itself if nothing applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg. The check must be on `arg`: `other` is already
    # known to be constant, so testing it would always stop at the first argument.
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the arguments that precede the constant one.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        concat_type = exp.Concat
        extra_args = {"safe": expression.args.get("safe")}

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for is_literal_run, run in runs:
        if is_literal_run:
            # Adjacent string literals collapse into a single literal.
            merged.append(exp.Literal.string(sep.join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if concat_type is exp.ConcatWs:
        merged = [sep_expr] + merged

    return concat_type(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        subject = expression.this
        for branch in expression.args["ifs"]:
            condition = branch.this
            if subject:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                condition = condition.replace(subject.pop().eq(condition))

            if always_true(condition):
                return branch.args["true"]

            if always_false(condition):
                branch.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
        return expression

    if isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        condition = expression.this
        if always_true(condition):
            return expression.args["true"]
        if always_false(condition):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.In'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('', ''), ('RIGHT', 'OUTER'), ('RIGHT', ''), ('', 'INNER')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Strip trivially-true predicates from WHERE clauses and joins, in place.

    Every WHERE whose condition satisfies ``always_true`` is unset on its
    parent. Every join whose ON condition satisfies ``always_true``, which has
    neither a USING clause nor a join method, and whose ``(side, kind)`` pair
    is listed in ``JOINS`` is rewritten as a CROSS join without an ON clause.

    Args:
        expression: the expression tree to search and mutate.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args

        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        # The ON condition filters nothing, so the join is a plain cross product.
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """Return a truthy value when the expression is statically known to be true.

    A boolean node yields its own value. A literal is considered true unless
    it is a numeric literal equal to zero: treating ``0`` as true would let
    callers wrongly simplify conditions such as ``CASE WHEN 0 THEN ...``.
    Anything else is not statically true.

    Args:
        expression: the expression (or None) to inspect.

    Returns:
        A truthy value if the expression is always true, a falsy one otherwise.
    """
    if isinstance(expression, exp.Boolean):
        return expression.this

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_string:
        return True

    try:
        return float(expression.name) != 0
    except ValueError:
        # Unparseable numeric text: keep the historical "literals are true" behavior.
        return True
def always_false(expression):
def always_false(expression):
    """Tell whether the expression is statically false: a FALSE boolean or a NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Tell whether ``b`` is a NOT node whose operand equals ``a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Tell whether ``a`` is exactly a Boolean node (not a subclass) holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Tell whether ``a`` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate a binary comparison node against two Python values.

    Args:
        expression: the comparison node; its type selects the operator.
        a: left-hand Python value.
        b: right-hand Python value.

    Returns:
        The boolean literal for the comparison result, or None when the node
        is not one of the supported comparison types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None

    return boolean_literal(outcome)
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a value to a ``datetime.date`` if possible.

    Args:
        value: a datetime (its date part is taken), a date (returned as is),
            or an ISO-format string to parse.

    Returns:
        The resulting date, or None when the value cannot be interpreted as
        one (a malformed string, or a value that is not a string at all).
    """
    # datetime is a subclass of date, so it has to be checked first.
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string (e.g. an int).
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a value to a ``datetime.datetime`` if possible.

    Args:
        value: a datetime (returned as is), a date (promoted to midnight of
            that day), or an ISO-format string to parse.

    Returns:
        The resulting datetime, or None when the value cannot be interpreted
        as one (a malformed string, or a value that is not a string at all).
    """
    # datetime is a subclass of date, so it has to be checked first.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string (e.g. an int).
        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert a Python value to the date-like type described by ``to``.

    Args:
        value: the value to convert; falsy values are not converted.
        to: the target data type.

    Returns:
        A date when ``to`` is DATE, a datetime when ``to`` is any other
        temporal type, and None when the value is falsy, the target type is
        not temporal, or the conversion fails.
    """
    if not value:
        return None
    # DATE is tested before the generic temporal check so it maps to a plain date.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a cast-like expression to a Python date or datetime.

    Supported shapes are a Cast node (target type taken from the cast) and a
    TsOrDsToDate node without a format argument (target type DATE). The
    operand must be a literal or, recursively, another supported node.

    Args:
        cast: the expression to evaluate.

    Returns:
        The date or datetime the expression denotes, or None when it cannot
        be determined statically.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: evaluate the inner one first, then re-cast its result.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """Build the time delta described by an interval-like expression.

    The amount is parsed as an integer from ``expression.name`` and the unit
    is the lower-cased ``unit`` text of the expression.

    Returns:
        Whatever ``interval`` returns for that unit and amount, or None when
        the amount is not an integer, the unit is unsupported, or the module
        needed by ``interval`` cannot be imported.
    """
    try:
        amount = int(expression.name)
        unit_name = expression.text("unit").lower()
        delta = interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
    return delta
def date_literal(date):
def date_literal(date):
    """Wrap a Python date/datetime in a cast of its string literal.

    The target type is DATETIME for ``datetime.datetime`` instances and DATE
    for everything else.
    """
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Return a ``relativedelta`` spanning ``n`` of the given unit.

    Args:
        unit: one of "year", "quarter", "month", "week", "day", "hour",
            "minute" or "second".
        n: how many units the delta should cover.

    Raises:
        UnsupportedUnit: if the unit is not one of the names above.
    """
    # Imported lazily, before anything else, exactly as the unit checks require it.
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, multiplier applied to n)
    unit_table = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for unit_name, keyword, factor in unit_table:
        if unit == unit_name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round a date down to the start of the enclosing unit.

    Args:
        d: the date to round down.
        unit: one of "year", "quarter", "month", "week" or "day".

    Returns:
        The first day of the year, quarter, month or week containing ``d``,
        or ``d`` itself for "day".

    Raises:
        UnsupportedUnit: if the unit is not one of the names above.
    """
    if unit == "day":
        return d

    if unit == "week":
        # weekday() is 0 on Monday, so this lands on the Monday of d's week.
        return d - datetime.timedelta(days=d.weekday())

    if unit == "month":
        return d.replace(day=1)

    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)

    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round a date up to the next unit boundary.

    A date that already sits on a boundary (it equals its own floor) is
    returned unchanged; otherwise the result is the floor plus one unit.
    """
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression for a truthy condition and FALSE otherwise."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Produce a cheap pseudo-SQL string usable as a sort and dedupe key.

    Optimization needs to sort and deduplicate expressions, and running the
    real SQL generator for that is expensive, so this is a deliberately
    minimal stand-in.

    None becomes "_", iterables are rendered element by element and joined
    with commas, and non-expression values fall back to ``str``. Expression
    types registered in ``GEN_MAP`` use their registered renderer; any other
    expression is rendered as its key followed by its rendered argument values.
    """
    if expression is None:
        return "_"

    if is_iterable(expression):
        return ",".join(map(gen, expression))

    if not isinstance(expression, exp.Expression):
        return str(expression)

    node_type = type(expression)
    if node_type not in GEN_MAP:
        return f"{expression.key} {gen(expression.args.values())}"
    return GEN_MAP[node_type](expression)

Simple pseudo SQL generator for quickly generating sortable and unique strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

GEN_MAP = {<class 'sqlglot.expressions.Add'>: <function <lambda>>, <class 'sqlglot.expressions.And'>: <function <lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function <lambda>>, <class 'sqlglot.expressions.Between'>: <function <lambda>>, <class 'sqlglot.expressions.Boolean'>: <function <lambda>>, <class 'sqlglot.expressions.Bracket'>: <function <lambda>>, <class 'sqlglot.expressions.Column'>: <function <lambda>>, <class 'sqlglot.expressions.DataType'>: <function <lambda>>, <class 'sqlglot.expressions.Div'>: <function <lambda>>, <class 'sqlglot.expressions.Dot'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.Identifier'>: <function <lambda>>, <class 'sqlglot.expressions.ILike'>: <function <lambda>>, <class 'sqlglot.expressions.In'>: <function <lambda>>, <class 'sqlglot.expressions.Is'>: <function <lambda>>, <class 'sqlglot.expressions.Like'>: <function <lambda>>, <class 'sqlglot.expressions.Literal'>: <function <lambda>>, <class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.Mod'>: <function <lambda>>, <class 'sqlglot.expressions.Mul'>: <function <lambda>>, <class 'sqlglot.expressions.Neg'>: <function <lambda>>, <class 'sqlglot.expressions.NEQ'>: <function <lambda>>, <class 'sqlglot.expressions.Not'>: <function <lambda>>, <class 'sqlglot.expressions.Null'>: <function <lambda>>, <class 'sqlglot.expressions.Or'>: <function <lambda>>, <class 'sqlglot.expressions.Paren'>: <function <lambda>>, <class 'sqlglot.expressions.Sub'>: <function <lambda>>, <class 'sqlglot.expressions.Subquery'>: <function <lambda>>, <class 'sqlglot.expressions.Table'>: <function <lambda>>, <class 'sqlglot.expressions.Var'>: <function <lambda>>}