Edit on GitHub

sqlglot.optimizer.simplify

   1import datetime
   2import functools
   3import itertools
   4import typing as t
   5from collections import deque
   6from decimal import Decimal
   7
   8import sqlglot
   9from sqlglot import exp
  10from sqlglot.generator import cached_generator
  11from sqlglot.helper import first, merge_ranges, while_changing
  12from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  13
# Meta key marking an expression as final: `simplify` sets it on GROUP BY related
# nodes and its inner `_simplify` returns such nodes untouched.
FINAL = "final"
  16
  17
  18class UnsupportedUnit(Exception):
  19    pass
  20
  21
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    The tree is rewritten repeatedly (via `while_changing`) until a pass leaves it
    unchanged; afterwards trivially-true WHERE / JOIN conditions are dropped.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # Shared SQL generator, handed to `uniq_sort` to key connector operands
    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains one of the group by keys
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same for the HAVING clause, when present
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes flagged above are returned untouched (children included)
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules below inspect `node.parent`, so restore it on the (possibly new) node first
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the root swaps itself into the tree; children are swapped in by replace_children
        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
  99
 100
 101def catch(*exceptions):
 102    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 103
 104    def decorator(func):
 105        def wrapped(expression, *args, **kwargs):
 106            try:
 107                return func(expression, *args, **kwargs)
 108            except exceptions:
 109                return expression
 110
 111        return wrapped
 112
 113    return decorator
 114
 115
 116def rewrite_between(expression: exp.Expression) -> exp.Expression:
 117    """Rewrite x between y and z to x >= y AND x <= z.
 118
 119    This is done because comparison simplification is only done on lt/lte/gt/gte.
 120    """
 121    if isinstance(expression, exp.Between):
 122        return exp.and_(
 123            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 124            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 125            copy=False,
 126        )
 127    return expression
 128
 129
 130def simplify_not(expression):
 131    """
 132    Demorgan's Law
 133    NOT (x OR y) -> NOT x AND NOT y
 134    NOT (x AND y) -> NOT x OR NOT y
 135    """
 136    if isinstance(expression, exp.Not):
 137        if is_null(expression.this):
 138            return exp.null()
 139        if isinstance(expression.this, exp.Paren):
 140            condition = expression.this.unnest()
 141            if isinstance(condition, exp.And):
 142                return exp.or_(
 143                    exp.not_(condition.left, copy=False),
 144                    exp.not_(condition.right, copy=False),
 145                    copy=False,
 146                )
 147            if isinstance(condition, exp.Or):
 148                return exp.and_(
 149                    exp.not_(condition.left, copy=False),
 150                    exp.not_(condition.right, copy=False),
 151                    copy=False,
 152                )
 153            if is_null(condition):
 154                return exp.null()
 155        if always_true(expression.this):
 156            return exp.false()
 157        if is_false(expression.this):
 158            return exp.true()
 159        if isinstance(expression.this, exp.Not):
 160            # double negation
 161            # NOT NOT x -> x
 162            return expression.this.this
 163    return expression
 164
 165
 166def flatten(expression):
 167    """
 168    A AND (B AND C) -> A AND B AND C
 169    A OR (B OR C) -> A OR B OR C
 170    """
 171    if isinstance(expression, exp.Connector):
 172        for node in expression.args.values():
 173            child = node.unnest()
 174            if isinstance(child, expression.__class__):
 175                node.replace(child)
 176    return expression
 177
 178
 179def simplify_connectors(expression, root=True):
 180    def _simplify_connectors(expression, left, right):
 181        if left == right:
 182            return left
 183        if isinstance(expression, exp.And):
 184            if is_false(left) or is_false(right):
 185                return exp.false()
 186            if is_null(left) or is_null(right):
 187                return exp.null()
 188            if always_true(left) and always_true(right):
 189                return exp.true()
 190            if always_true(left):
 191                return right
 192            if always_true(right):
 193                return left
 194            return _simplify_comparison(expression, left, right)
 195        elif isinstance(expression, exp.Or):
 196            if always_true(left) or always_true(right):
 197                return exp.true()
 198            if is_false(left) and is_false(right):
 199                return exp.false()
 200            if (
 201                (is_null(left) and is_null(right))
 202                or (is_null(left) and is_false(right))
 203                or (is_false(left) and is_null(right))
 204            ):
 205                return exp.null()
 206            if is_false(left):
 207                return right
 208            if is_false(right):
 209                return left
 210            return _simplify_comparison(expression, left, right, or_=True)
 211
 212    if isinstance(expression, exp.Connector):
 213        return _flat_simplify(expression, _simplify_connectors, root)
 214    return expression
 215
 216
# Inequality classes grouped by direction
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Binary predicates that the comparison-based rules in this module operate on
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps an inequality to the equivalent one once its two operands are swapped
# (1 < x is x > 1)
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
 234
 235
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to collapse a pair of comparisons that share a column operand.

    Args:
        expression: the connector whose operands are being combined; it is returned
            as-is when a comparison has no operand besides the shared column(s).
        left: first comparison.
        right: second comparison.
        or_: True when the pair is combined with OR, False for AND.

    Returns:
        One of the two comparisons (possibly rebuilt so the column is on the left),
        FALSE for contradictory AND-ed predicates, `expression`, or None when no
        rule applies.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands appearing in both comparisons; only shared columns are of interest
        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                # The "other" operand of each comparison, i.e. the value compared to the column
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            # Only number/number or string/string pairs can be ordered statically
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # Try both (left, right) and (right, left) so each rule is written only once
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Same direction: keep the tighter bound for AND, the looser one for OR
                # NOTE(review): with equal bounds this keeps `left` whichever side is strict
                # (e.g. x <= 1 AND x < 1 keeps x <= 1) -- confirm operand order makes this safe
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Opposite directions with non-overlapping ranges can never both hold
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality either contradicts the other predicate or subsumes it
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 295
 296
 297def remove_complements(expression, root=True):
 298    """
 299    Removing complements.
 300
 301    A AND NOT A -> FALSE
 302    A OR NOT A -> TRUE
 303    """
 304    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 305        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 306
 307        for a, b in itertools.permutations(expression.flatten(), 2):
 308            if is_complement(a, b):
 309                return complement
 310    return expression
 311
 312
 313def uniq_sort(expression, generate, root=True):
 314    """
 315    Uniq and sort a connector.
 316
 317    C AND A AND B AND B -> A AND B AND C
 318    """
 319    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 320        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 321        flattened = tuple(expression.flatten())
 322        deduped = {generate(e): e for e in flattened}
 323        arr = tuple(deduped.items())
 324
 325        # check if the operands are already sorted, if not sort them
 326        # A AND C AND B -> A AND B AND C
 327        for i, (sql, e) in enumerate(arr[1:]):
 328            if sql < arr[i][0]:
 329                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 330                break
 331        else:
 332            # we didn't have to sort but maybe we need to dedup
 333            if len(deduped) < len(flattened):
 334                expression = result_func(*deduped.values(), copy=False)
 335
 336    return expression
 337
 338
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are done in place by replacing operands with boolean constants
    (or with the surviving operand); the input connector itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The connector type that nested operands must have for these rules to apply
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for the identity element of `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's: a is redundant next to b
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 377
 378
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html

    Columns are replaced in place; the input expression is returned.
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node inside its defining equality, literal)
        constant_mapping = {}
        # exp.If nodes are pruned from the walk
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # never rewrite the column of the defining `col = literal` itself
                    and id(column) != column_id
                    # leave `col IS NULL` checks alone
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
 422
 423
# Date functions mapped to the arithmetic operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations `simplify_equality` can move to the other side of a comparison
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 436
 437
def _is_number(expression: exp.Expression) -> bool:
    """Predicate form of `expression.is_number`, used as a callback by `simplify_equality`."""
    return expression.is_number
 440
 441
 442def _is_interval(expression: exp.Expression) -> bool:
 443    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 444
 445
 446@catch(ModuleNotFoundError, UnsupportedUnit)
 447def simplify_equality(expression: exp.Expression) -> exp.Expression:
 448    """
 449    Use the subtraction and addition properties of equality to simplify expressions:
 450
 451        x + 1 = 3 becomes x = 2
 452
 453    There are two binary operations in the above expression: + and =
 454    Here's how we reference all the operands in the code below:
 455
 456          l     r
 457        x + 1 = 3
 458        a   b
 459    """
 460    if isinstance(expression, COMPARISONS):
 461        l, r = expression.left, expression.right
 462
 463        if l.__class__ in INVERSE_OPS:
 464            pass
 465        elif r.__class__ in INVERSE_OPS:
 466            l, r = r, l
 467        else:
 468            return expression
 469
 470        if r.is_number:
 471            a_predicate = _is_number
 472            b_predicate = _is_number
 473        elif _is_date_literal(r):
 474            a_predicate = _is_date_literal
 475            b_predicate = _is_interval
 476        else:
 477            return expression
 478
 479        if l.__class__ in INVERSE_DATE_OPS:
 480            a = l.this
 481            b = l.interval()
 482        else:
 483            a, b = l.left, l.right
 484
 485        if not a_predicate(a) and b_predicate(b):
 486            pass
 487        elif not a_predicate(b) and b_predicate(a):
 488            a, b = b, a
 489        else:
 490            return expression
 491
 492        return expression.__class__(
 493            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 494        )
 495    return expression
 496
 497
 498def simplify_literals(expression, root=True):
 499    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 500        return _flat_simplify(expression, _simplify_binary, root)
 501
 502    if isinstance(expression, exp.Neg):
 503        this = expression.this
 504        if this.is_number:
 505            value = this.name
 506            if value[0] == "-":
 507                return exp.Literal.number(value[1:])
 508            return exp.Literal.number(f"-{value}")
 509
 510    return expression
 511
 512
 513def _simplify_binary(expression, a, b):
 514    if isinstance(expression, exp.Is):
 515        if isinstance(b, exp.Not):
 516            c = b.this
 517            not_ = True
 518        else:
 519            c = b
 520            not_ = False
 521
 522        if is_null(c):
 523            if isinstance(a, exp.Literal):
 524                return exp.true() if not_ else exp.false()
 525            if is_null(a):
 526                return exp.false() if not_ else exp.true()
 527    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
 528        return None
 529    elif is_null(a) or is_null(b):
 530        return exp.null()
 531
 532    if a.is_number and b.is_number:
 533        a = int(a.name) if a.is_int else Decimal(a.name)
 534        b = int(b.name) if b.is_int else Decimal(b.name)
 535
 536        if isinstance(expression, exp.Add):
 537            return exp.Literal.number(a + b)
 538        if isinstance(expression, exp.Sub):
 539            return exp.Literal.number(a - b)
 540        if isinstance(expression, exp.Mul):
 541            return exp.Literal.number(a * b)
 542        if isinstance(expression, exp.Div):
 543            # engines have differing int div behavior so intdiv is not safe
 544            if isinstance(a, int) and isinstance(b, int):
 545                return None
 546            return exp.Literal.number(a / b)
 547
 548        boolean = eval_boolean(expression, a, b)
 549
 550        if boolean:
 551            return boolean
 552    elif a.is_string and b.is_string:
 553        boolean = eval_boolean(expression, a.this, b.this)
 554
 555        if boolean:
 556            return boolean
 557    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 558        a, b = extract_date(a), extract_interval(b)
 559        if a and b:
 560            if isinstance(expression, exp.Add):
 561                return date_literal(a + b)
 562            if isinstance(expression, exp.Sub):
 563                return date_literal(a - b)
 564    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 565        a, b = extract_interval(a), extract_date(b)
 566        # you cannot subtract a date from an interval
 567        if a and b and isinstance(expression, exp.Add):
 568            return date_literal(a + b)
 569
 570    return None
 571
 572
def simplify_parens(expression):
    """Return the content of an `exp.Paren` when one of the conditions below holds.

    Non-Paren expressions, parenthesized SELECTs and parens matching none of the
    conditions are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    # Parenthesized SELECTs are always kept; otherwise any one condition suffices
    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
 591
 592
# Expression classes that `simplify_coalesce` treats as constants
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)
 598
 599
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    A single-argument COALESCE(x) becomes x. A comparison between a COALESCE and a
    constant is rewritten, when the COALESCE has a constant argument, to

        (<head> IS NOT NULL AND <comparison>) OR (<head> IS NULL AND <constant arg> <op> <other>)

    where <head> is the COALESCE truncated before its first constant argument.
    Everything else is returned unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE operand, whichever side it is on
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Drop the constant and everything after it from the COALESCE (mutates it in place)
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 656
 657
# Concatenation nodes handled by `simplify_concat`
CONCATS = (exp.Concat, exp.DPipe)
# Variants for which `simplify_concat` rebuilds an `exp.SafeConcat`
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
 660
 661
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Returns the input unchanged when it is not one of CONCATS or is a CONCAT_WS with
    a non-literal separator; a single string literal when every argument folds into
    one; otherwise a new concat node with each run of adjacent string literals merged.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs)
        and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # The first argument of CONCAT_WS is the separator
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat

    new_args = []
    # Falls back to `expression.flatten()` when there is no `expressions` list
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            # Merge a run of adjacent string literals into one literal
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        # Put the separator back in front of the merged arguments
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args)
 696
 697
# A (start, end) pair of dates; `_datetrunc_range` documents it as the half-open range [start, end)
DateRange = t.Tuple[datetime.date, datetime.date]
 699
 700
 701def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
 702    """
 703    Get the date range for a DATE_TRUNC equality comparison:
 704
 705    Example:
 706        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 707    Returns:
 708        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 709    """
 710    floor = date_floor(date, unit)
 711
 712    if date != floor:
 713        # This will always be False, except for NULL values.
 714        return None
 715
 716    return floor, floor + interval(unit)
 717
 718
 719def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 720    """Get the logical expression for a date range"""
 721    return exp.and_(
 722        left >= date_literal(drange[0]),
 723        left < date_literal(drange[1]),
 724        copy=False,
 725    )
 726
 727
 728def _datetrunc_eq(
 729    left: exp.Expression, date: datetime.date, unit: str
 730) -> t.Optional[exp.Expression]:
 731    drange = _datetrunc_range(date, unit)
 732    if not drange:
 733        return None
 734
 735    return _datetrunc_eq_expression(left, drange)
 736
 737
 738def _datetrunc_neq(
 739    left: exp.Expression, date: datetime.date, unit: str
 740) -> t.Optional[exp.Expression]:
 741    drange = _datetrunc_range(date, unit)
 742    if not drange:
 743        return None
 744
 745    return exp.and_(
 746        left < date_literal(drange[0]),
 747        left >= date_literal(drange[1]),
 748        copy=False,
 749    )
 750
 751
 752DateTruncBinaryTransform = t.Callable[
 753    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
 754]
 755DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 756    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
 757    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
 758    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
 759    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
 760    exp.EQ: _datetrunc_eq,
 761    exp.NEQ: _datetrunc_neq,
 762}
 763DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 764
 765
 766def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 767    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
 768
 769
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Binary comparisons are rewritten through DATETRUNC_BINARY_COMPARISONS; an IN list
    of date literals becomes an OR of range checks. The input is returned unchanged
    whenever the pattern does not match.
    """
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            # Literal on the left: mirror the comparison and swap the operands
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # The transform may return None, in which case the expression is kept as-is
        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                # Dates without a range (see _datetrunc_range) are skipped
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression
 820
 821
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
# Each entry is a (side, kind) pair that `remove_where_true` may turn into a CROSS join.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
 831
 832
 833def remove_where_true(expression):
 834    for where in expression.find_all(exp.Where):
 835        if always_true(where.this):
 836            where.parent.set("where", None)
 837    for join in expression.find_all(exp.Join):
 838        if (
 839            always_true(join.args.get("on"))
 840            and not join.args.get("using")
 841            and not join.args.get("method")
 842            and (join.side, join.kind) in JOINS
 843        ):
 844            join.set("on", None)
 845            join.set("side", None)
 846            join.set("kind", "CROSS")
 847
 848
 849def always_true(expression):
 850    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 851        expression, exp.Literal
 852    )
 853
 854
 855def is_complement(a, b):
 856    return isinstance(b, exp.Not) and b.this == a
 857
 858
 859def is_false(a: exp.Expression) -> bool:
 860    return type(a) is exp.Boolean and not a.this
 861
 862
def is_null(a: exp.Expression) -> bool:
    """Return True only when `a` is exactly an `exp.Null` node (subclasses do not count)."""
    return type(a) is exp.Null
 865
 866
 867def eval_boolean(expression, a, b):
 868    if isinstance(expression, (exp.EQ, exp.Is)):
 869        return boolean_literal(a == b)
 870    if isinstance(expression, exp.NEQ):
 871        return boolean_literal(a != b)
 872    if isinstance(expression, exp.GT):
 873        return boolean_literal(a > b)
 874    if isinstance(expression, exp.GTE):
 875        return boolean_literal(a >= b)
 876    if isinstance(expression, exp.LT):
 877        return boolean_literal(a < b)
 878    if isinstance(expression, exp.LTE):
 879        return boolean_literal(a <= b)
 880    return None
 881
 882
 883def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 884    if isinstance(value, datetime.datetime):
 885        return value.date()
 886    if isinstance(value, datetime.date):
 887        return value
 888    try:
 889        return datetime.datetime.fromisoformat(value).date()
 890    except ValueError:
 891        return None
 892
 893
 894def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 895    if isinstance(value, datetime.datetime):
 896        return value
 897    if isinstance(value, datetime.date):
 898        return datetime.datetime(year=value.year, month=value.month, day=value.day)
 899    try:
 900        return datetime.datetime.fromisoformat(value)
 901    except ValueError:
 902        return None
 903
 904
 905def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 906    if not value:
 907        return None
 908    if to.is_type(exp.DataType.Type.DATE):
 909        return cast_as_date(value)
 910    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 911        return cast_as_datetime(value)
 912    return None
 913
 914
 915def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 916    if isinstance(cast, exp.Cast):
 917        to = cast.to
 918    elif isinstance(cast, exp.TsOrDsToDate):
 919        to = exp.DataType.build(exp.DataType.Type.DATE)
 920    else:
 921        return None
 922
 923    if isinstance(cast.this, exp.Literal):
 924        value: t.Any = cast.this.name
 925    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 926        value = extract_date(cast.this)
 927    else:
 928        return None
 929    return cast_value(value, to)
 930
 931
 932def _is_date_literal(expression: exp.Expression) -> bool:
 933    return extract_date(expression) is not None
 934
 935
 936def extract_interval(expression):
 937    n = int(expression.name)
 938    unit = expression.text("unit").lower()
 939
 940    try:
 941        return interval(unit, n)
 942    except (UnsupportedUnit, ModuleNotFoundError):
 943        return None
 944
 945
 946def date_literal(date):
 947    return exp.cast(
 948        exp.Literal.string(date),
 949        exp.DataType.Type.DATETIME
 950        if isinstance(date, datetime.datetime)
 951        else exp.DataType.Type.DATE,
 952    )
 953
 954
 955def interval(unit: str, n: int = 1):
 956    from dateutil.relativedelta import relativedelta
 957
 958    if unit == "year":
 959        return relativedelta(years=1 * n)
 960    if unit == "quarter":
 961        return relativedelta(months=3 * n)
 962    if unit == "month":
 963        return relativedelta(months=1 * n)
 964    if unit == "week":
 965        return relativedelta(weeks=1 * n)
 966    if unit == "day":
 967        return relativedelta(days=1 * n)
 968    if unit == "hour":
 969        return relativedelta(hours=1 * n)
 970    if unit == "minute":
 971        return relativedelta(minutes=1 * n)
 972    if unit == "second":
 973        return relativedelta(seconds=1 * n)
 974
 975    raise UnsupportedUnit(f"Unsupported unit: {unit}")
 976
 977
 978def date_floor(d: datetime.date, unit: str) -> datetime.date:
 979    if unit == "year":
 980        return d.replace(month=1, day=1)
 981    if unit == "quarter":
 982        if d.month <= 3:
 983            return d.replace(month=1, day=1)
 984        elif d.month <= 6:
 985            return d.replace(month=4, day=1)
 986        elif d.month <= 9:
 987            return d.replace(month=7, day=1)
 988        else:
 989            return d.replace(month=10, day=1)
 990    if unit == "month":
 991        return d.replace(month=d.month, day=1)
 992    if unit == "week":
 993        # Assuming week starts on Monday (0) and ends on Sunday (6)
 994        return d - datetime.timedelta(days=d.weekday())
 995    if unit == "day":
 996        return d
 997
 998    raise UnsupportedUnit(f"Unsupported unit: {unit}")
 999
1000
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary; `d` is returned as is when already on one."""
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
1008
1009
def boolean_literal(condition):
    """Return the TRUE node when `condition` is truthy, the FALSE node otherwise."""
    if condition:
        return exp.true()
    return exp.false()
1012
1013
1014def _flat_simplify(expression, simplifier, root=True):
1015    if root or not expression.same_parent:
1016        operands = []
1017        queue = deque(expression.flatten(unnest=False))
1018        size = len(queue)
1019
1020        while queue:
1021            a = queue.popleft()
1022
1023            for b in queue:
1024                result = simplifier(expression, a, b)
1025
1026                if result and result is not expression:
1027                    queue.remove(b)
1028                    queue.appendleft(result)
1029                    break
1030            else:
1031                operands.append(a)
1032
1033        if len(operands) < size:
1034            return functools.reduce(
1035                lambda a, b: expression.__class__(this=a, expression=b), operands
1036            )
1037    return expression
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date/time unit is not handled by the date helpers."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression, constant_propagation=False):
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # Shared generator handed to uniq_sort, which keys operands by their generated SQL.
    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains one of the group keys.
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same for the HAVING clause, if there is one.
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Expressions flagged FINAL above are returned untouched.
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse through exp.replace_children; nested calls run with root=False.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have returned a different node; give it the original's
        # parent before the remaining rules run (simplify_parens reads the parent).
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the root call swaps the result into the tree.
        if root:
            expression.replace(node)

        return node

    # NOTE(review): while_changing presumably reapplies _simplify until the
    # expression stops changing - confirm against sqlglot.helper.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether or not the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes to swallow.

    Returns:
        A decorator. The decorated function returns its first argument
        (`expression`) unchanged whenever one of `exceptions` is raised, and
        keeps the wrapped function's name and docstring.
    """

    def decorator(func):
        # functools.wraps keeps the original metadata so that introspection and
        # generated documentation show the real function rather than `wrapped`.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of `exceptions` are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Any other expression is returned unchanged. The expansion is done because
    comparison simplification is only done on lt/lte/gt/gte.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a negation.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL to NULL, NOT <always true> to FALSE, NOT FALSE to TRUE
    and NOT NOT x to x. Anything else is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, exp.And):
            return exp.or_(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if isinstance(inner, exp.Or):
            return exp.and_(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Remove parentheses around operands of the same connector type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, expression.__class__):
            operand.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Fold TRUE/FALSE/NULL operands and comparable pairs inside an AND/OR chain.

    Non-connector expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    def _combine(connector, lhs, rhs):
        # Returns the simplified pair, or None when nothing applies.
        if lhs == rhs:
            return lhs

        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()

            lhs_true = always_true(lhs)
            rhs_true = always_true(rhs)
            if lhs_true and rhs_true:
                return exp.true()
            if lhs_true:
                return rhs
            if rhs_true:
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()

            lhs_false, rhs_false = is_false(lhs), is_false(rhs)
            if lhs_false and rhs_false:
                return exp.false()

            lhs_null, rhs_null = is_null(lhs), is_null(rhs)
            if (lhs_null and (rhs_null or rhs_false)) or (lhs_false and rhs_null):
                return exp.null()
            if lhs_false:
                return rhs
            if rhs_false:
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        return None

    return _flat_simplify(expression, _combine, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keyed by SQL text, so operands with the same SQL collapse into one.
    by_sql = {generate(operand): operand for operand in operands}
    pairs = tuple(by_sql.items())

    # Re-sort only if some operand sorts before its predecessor:
    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(pairs, pairs[1:]))
    if out_of_order:
        return build(*(operand for _, operand in sorted(pairs)), copy=False)

    # Already sorted, but duplicates may still have been dropped.
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten with TRUE/FALSE placeholders or with the shared
    operand; the expression itself is always returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The opposite connector: inside an AND we look for OR operands and vice versa.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for the identity element of `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's: swap the whole of a
                    # for the identity element of the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one operand and the other two are complements
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Columns are replaced in place; the expression itself is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node inside its `column = literal` equality, literal)
        constant_mapping = {}
        # NOTE(review): `prune` presumably stops the walk from descending into IF nodes - confirm.
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    # normalize to (column, literal)
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # leave the node that defines the constant alone
                    and id(column) != column_id
                    # leave `column IS NULL` checks alone
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
106        def wrapped(expression, *args, **kwargs):
107            try:
108                return func(expression, *args, **kwargs)
109            except exceptions:
110                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literals in non-connector binary expressions and negated numbers.

    Binary expressions other than connectors are delegated to `_flat_simplify`
    with `_simplify_binary`. A negated numeric literal becomes a single number
    literal with its sign flipped. Anything else is returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(flipped)
def simplify_parens(expression):
def simplify_parens(expression):
    """Strip a pair of parentheses when they are not needed.

    Returns the wrapped expression when the parentheses can go, and
    `expression` itself otherwise. Parenthesized SELECTs are never unwrapped.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """Simplify COALESCE calls.

    Two rewrites are applied:

    - COALESCE(x) with no further arguments becomes x, unless its parent is a hint.
    - A comparison between a COALESCE and a constant, where the COALESCE has a
      constant among its extra arguments, is split into
      (<coalesce> IS NOT NULL AND <comparison>) OR (<coalesce> IS NULL AND <constant> <op> <other>),
      with the COALESCE arguments cut off before that first constant.

    Anything else is returned unchanged. COMPARISONS and CONSTANTS are tuples
    of expression types defined elsewhere in this module.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Find which side of the comparison holds the COALESCE.
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Drop the constant and everything after it (mutates the COALESCE in place).
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                # copied after the truncation above, so it holds the shortened COALESCE
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
SAFE_CONCATS = (<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.SafeDPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Runs of adjacent string literals are merged into one literal (joined with
    the separator for CONCAT_WS). A concat that reduces to a single string
    returns that literal; non-concat expressions are returned unchanged.
    """
    if not isinstance(expression, CONCATS):
        return expression

    has_separator = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if has_separator and not expression.expressions[0].is_string:
        return expression

    if has_separator:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
    else:
        operands = expression.expressions
        sep = ""
        if isinstance(expression, SAFE_CONCATS):
            concat_type = exp.SafeConcat
        else:
            concat_type = exp.Concat

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            joined = sep.join(literal.name for literal in run)
            merged.append(exp.Literal.string(joined))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if concat_type is exp.ConcatWs:
        merged = [sep_expr] + merged

    return concat_type(expressions=merged)

Reduces all groups that contain string literals by concatenating them.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GTE'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
106        def wrapped(expression, *args, **kwargs):
107            try:
108                return func(expression, *args, **kwargs)
109            except exceptions:
110                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('RIGHT', 'OUTER'), ('', 'INNER'), ('RIGHT', ''), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Drop WHERE clauses and join conditions that are always true.

    A WHERE whose condition is always true is removed from its parent. A join
    whose ON condition is always true, that has no USING or method, and whose
    (side, kind) pair is listed in JOINS is turned into a CROSS join.
    The expression is modified in place; nothing is returned.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join_node in expression.find_all(exp.Join):
        if not always_true(join_node.args.get("on")):
            continue
        if join_node.args.get("using") or join_node.args.get("method"):
            continue
        if (join_node.side, join_node.kind) not in JOINS:
            continue

        join_node.set("on", None)
        join_node.set("side", None)
        join_node.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """Return a truthy value for a true boolean node or for any literal node."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal)
def is_complement(a, b):
def is_complement(a, b):
    """Tell whether `b` is the negation of `a`, i.e. `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Tell whether `a` is exactly a boolean node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Tell whether `a` is exactly a NULL node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal node for EQ/IS, NEQ, GT, GTE, LT and LTE nodes,
    and None for any other kind of expression.
    """
    # Checked in order, first match wins.
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a date.

    Datetimes are truncated to their date, dates are returned as is, and any
    other value is parsed as an ISO-format string; None is returned when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a datetime.

    Datetimes are returned as is, dates become midnight of that day, and any
    other value is parsed as an ISO-format string; None is returned when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date
        if isinstance(value, datetime.datetime):
            return value
        return datetime.datetime(year=value.year, month=value.month, day=value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python temporal type matching the data type `to`.

    Args:
        value: the value to convert; falsy values yield None.
        to: the target data type.

    Returns:
        A date when `to` is DATE, a datetime when `to` is any other temporal
        type, and None when the type is not temporal or the value cannot be
        converted.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a (possibly nested) cast of a literal to a Python date or datetime.

    Args:
        cast: a Cast or TsOrDsToDate node; TsOrDsToDate is treated as a cast to DATE.

    Returns:
        The converted value, or None when `cast` is not one of those nodes, its
        operand is neither a literal nor another such node, or the conversion fails.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first, then convert its result.
        value = extract_date(operand)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """Convert an interval node into the delta built by `interval`.

    Args:
        expression: node whose `name` holds the quantity and whose "unit" arg
            holds the unit (presumably an exp.Interval — confirm against callers).

    Returns:
        The delta, or None when the quantity is not an integer, the unit is
        unsupported, or dateutil is not installed.
    """
    try:
        n = int(expression.name)
    except ValueError:
        # Non-integer quantities cannot be evaluated statically; give up
        # instead of letting the error escape.
        return None

    unit = expression.text("unit").lower()

    try:
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError):
        return None
def date_literal(date):
def date_literal(date):
    """Build a cast of the string form of `date` to DATETIME (for datetimes) or DATE."""
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE
    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Return a dateutil `relativedelta` spanning `n` of the given `unit`.

    Supported units: year, quarter, month, week, day, hour, minute, second.

    Raises:
        UnsupportedUnit: for any other unit.
        ModuleNotFoundError: when dateutil is not installed.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, how many of that keyword make one unit)
    conversions = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in conversions:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of its year, quarter, month, week or day.

    Weeks start on Monday. For the "day" unit `d` is returned unchanged.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday through 6 for Sunday
        days_since_monday = d.weekday()
        return d - datetime.timedelta(days=days_since_monday)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary; `d` is returned as is when already on one."""
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE node when `condition` is truthy, the FALSE node otherwise."""
    if condition:
        return exp.true()
    return exp.false()