Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import functools
   5import itertools
   6import typing as t
   7from collections import deque
   8from decimal import Decimal
   9
  10import sqlglot
  11from sqlglot import Dialect, exp
  12from sqlglot.helper import first, merge_ranges, while_changing
  13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot.dialects.dialect import DialectType
  17
  18    DateTruncBinaryTransform = t.Callable[
  19        [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
  20    ]
  21
# Key stored in `Expression.meta`: an expression flagged with it is returned
# unchanged by `simplify` (it should not be simplified)
FINAL = "final"

# Inclusive value ranges for byte-sized signed/unsigned integers, used by
# `_simplify_integer_cast` to decide whether a cast around a literal can be dropped
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  30
  31
  32class UnsupportedUnit(Exception):
  33    pass
  34
  35
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect (resolved with `Dialect.get_or_raise`) that is passed on to
            the DATE_TRUNC simplification rule

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Expressions flagged as FINAL are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules below read `node.parent`, so point it at the original parent first
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 118
 119
 120def catch(*exceptions):
 121    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 122
 123    def decorator(func):
 124        def wrapped(expression, *args, **kwargs):
 125            try:
 126                return func(expression, *args, **kwargs)
 127            except exceptions:
 128                return expression
 129
 130        return wrapped
 131
 132    return decorator
 133
 134
 135def rewrite_between(expression: exp.Expression) -> exp.Expression:
 136    """Rewrite x between y and z to x >= y AND x <= z.
 137
 138    This is done because comparison simplification is only done on lt/lte/gt/gte.
 139    """
 140    if isinstance(expression, exp.Between):
 141        negate = isinstance(expression.parent, exp.Not)
 142
 143        expression = exp.and_(
 144            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 145            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 146            copy=False,
 147        )
 148
 149        if negate:
 150            expression = exp.paren(expression, copy=False)
 151
 152    return expression
 153
 154
# Maps a comparison class to the class of its logical negation;
# `simplify_not` uses it to rewrite NOT <comparison> as a single comparison
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 163
 164
 165def simplify_not(expression):
 166    """
 167    Demorgan's Law
 168    NOT (x OR y) -> NOT x AND NOT y
 169    NOT (x AND y) -> NOT x OR NOT y
 170    """
 171    if isinstance(expression, exp.Not):
 172        this = expression.this
 173        if is_null(this):
 174            return exp.null()
 175        if this.__class__ in COMPLEMENT_COMPARISONS:
 176            return COMPLEMENT_COMPARISONS[this.__class__](
 177                this=this.this, expression=this.expression
 178            )
 179        if isinstance(this, exp.Paren):
 180            condition = this.unnest()
 181            if isinstance(condition, exp.And):
 182                return exp.paren(
 183                    exp.or_(
 184                        exp.not_(condition.left, copy=False),
 185                        exp.not_(condition.right, copy=False),
 186                        copy=False,
 187                    )
 188                )
 189            if isinstance(condition, exp.Or):
 190                return exp.paren(
 191                    exp.and_(
 192                        exp.not_(condition.left, copy=False),
 193                        exp.not_(condition.right, copy=False),
 194                        copy=False,
 195                    )
 196                )
 197            if is_null(condition):
 198                return exp.null()
 199        if always_true(this):
 200            return exp.false()
 201        if is_false(this):
 202            return exp.true()
 203        if isinstance(this, exp.Not):
 204            # double negation
 205            # NOT NOT x -> x
 206            return this.this
 207    return expression
 208
 209
 210def flatten(expression):
 211    """
 212    A AND (B AND C) -> A AND B AND C
 213    A OR (B OR C) -> A OR B OR C
 214    """
 215    if isinstance(expression, exp.Connector):
 216        for node in expression.args.values():
 217            child = node.unnest()
 218            if isinstance(child, expression.__class__):
 219                node.replace(child)
 220    return expression
 221
 222
 223def simplify_connectors(expression, root=True):
 224    def _simplify_connectors(expression, left, right):
 225        if left == right:
 226            return left
 227        if isinstance(expression, exp.And):
 228            if is_false(left) or is_false(right):
 229                return exp.false()
 230            if is_null(left) or is_null(right):
 231                return exp.null()
 232            if always_true(left) and always_true(right):
 233                return exp.true()
 234            if always_true(left):
 235                return right
 236            if always_true(right):
 237                return left
 238            return _simplify_comparison(expression, left, right)
 239        elif isinstance(expression, exp.Or):
 240            if always_true(left) or always_true(right):
 241                return exp.true()
 242            if is_false(left) and is_false(right):
 243                return exp.false()
 244            if (
 245                (is_null(left) and is_null(right))
 246                or (is_null(left) and is_false(right))
 247                or (is_false(left) and is_null(right))
 248            ):
 249                return exp.null()
 250            if is_false(left):
 251                return right
 252            if is_false(right):
 253                return left
 254            return _simplify_comparison(expression, left, right, or_=True)
 255
 256    if isinstance(expression, exp.Connector):
 257        return _flat_simplify(expression, _simplify_connectors, root)
 258    return expression
 259
 260
# "Less than" / "greater than" comparison classes, grouped by direction
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Expression classes that the rules in this module treat as comparisons
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps an inequality class to the one pointing in the opposite direction
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Operands containing any of these are not used as the shared operand
# when `_simplify_comparison` merges two comparisons
NONDETERMINISTIC = (exp.Rand, exp.Randn)
 280
 281
def _simplify_comparison(expression, left, right, or_=False):
    """Try to merge two comparisons that share a non-constant operand.

    Args:
        expression: the connector the two operands belong to.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the operands are joined by OR, False for AND.

    Returns:
        A replacement for the pair, `expression` itself when a comparison has no
        operand besides the shared one, or None when nothing can be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands that appear in both comparisons; only the non-constant,
        # deterministic ones can act as the shared "column"
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The remaining operand of each comparison is the value it is compared to
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Bring both values into a comparable Python form: float, str or date
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None

            # Try both (left, right) and (right, left) as (a, b)
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two bounds in the same direction: keep only one of them
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality either contradicts the other comparison or makes it redundant
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 340
 341
 342def remove_complements(expression, root=True):
 343    """
 344    Removing complements.
 345
 346    A AND NOT A -> FALSE
 347    A OR NOT A -> TRUE
 348    """
 349    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 350        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 351
 352        for a, b in itertools.permutations(expression.flatten(), 2):
 353            if is_complement(a, b):
 354                return complement
 355    return expression
 356
 357
 358def uniq_sort(expression, root=True):
 359    """
 360    Uniq and sort a connector.
 361
 362    C AND A AND B AND B -> A AND B AND C
 363    """
 364    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 365        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 366        flattened = tuple(expression.flatten())
 367        deduped = {gen(e): e for e in flattened}
 368        arr = tuple(deduped.items())
 369
 370        # check if the operands are already sorted, if not sort them
 371        # A AND C AND B -> A AND B AND C
 372        for i, (sql, e) in enumerate(arr[1:]):
 373            if sql < arr[i][0]:
 374                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 375                break
 376        else:
 377            # we didn't have to sort but maybe we need to dedup
 378            if len(deduped) < len(flattened):
 379                expression = result_func(*deduped.values(), copy=False)
 380
 381    return expression
 382
 383
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, editing its operands in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced with TRUE / FALSE or with the surviving operand; the
    (mutated) `expression` itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector type that can be absorbed: OR inside AND, AND inside OR
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for the constant that is neutral inside `a`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a proper subset of a's, so `a` is redundant
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 422
 423
 424def propagate_constants(expression, root=True):
 425    """
 426    Propagate constants for conjunctions in DNF:
 427
 428    SELECT * FROM t WHERE a = b AND b = 5 becomes
 429    SELECT * FROM t WHERE a = 5 AND b = 5
 430
 431    Reference: https://www.sqlite.org/optoverview.html
 432    """
 433
 434    if (
 435        isinstance(expression, exp.And)
 436        and (root or not expression.same_parent)
 437        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 438    ):
 439        constant_mapping = {}
 440        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 441            if isinstance(expr, exp.EQ):
 442                l, r = expr.left, expr.right
 443
 444                # TODO: create a helper that can be used to detect nested literal expressions such
 445                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 446                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 447                    constant_mapping[l] = (id(l), r)
 448
 449        if constant_mapping:
 450            for column in find_all_in_scope(expression, exp.Column):
 451                parent = column.parent
 452                column_id, constant = constant_mapping.get(column) or (None, None)
 453                if (
 454                    column_id is not None
 455                    and id(column) != column_id
 456                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 457                ):
 458                    column.replace(constant.copy())
 459
 460    return expression
 461
 462
# Date arithmetic expressions mapped to the plain binary operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations `simplify_equality` can move to the other side of a comparison,
# mapped to the operator used to undo them
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 475
 476
def _is_number(expression: exp.Expression) -> bool:
    """Return the expression's `is_number` flag; used as an operand predicate in `simplify_equality`."""
    return expression.is_number
 479
 480
 481def _is_interval(expression: exp.Expression) -> bool:
 482    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 483
 484
 485@catch(ModuleNotFoundError, UnsupportedUnit)
 486def simplify_equality(expression: exp.Expression) -> exp.Expression:
 487    """
 488    Use the subtraction and addition properties of equality to simplify expressions:
 489
 490        x + 1 = 3 becomes x = 2
 491
 492    There are two binary operations in the above expression: + and =
 493    Here's how we reference all the operands in the code below:
 494
 495          l     r
 496        x + 1 = 3
 497        a   b
 498    """
 499    if isinstance(expression, COMPARISONS):
 500        l, r = expression.left, expression.right
 501
 502        if l.__class__ not in INVERSE_OPS:
 503            return expression
 504
 505        if r.is_number:
 506            a_predicate = _is_number
 507            b_predicate = _is_number
 508        elif _is_date_literal(r):
 509            a_predicate = _is_date_literal
 510            b_predicate = _is_interval
 511        else:
 512            return expression
 513
 514        if l.__class__ in INVERSE_DATE_OPS:
 515            l = t.cast(exp.IntervalOp, l)
 516            a = l.this
 517            b = l.interval()
 518        else:
 519            l = t.cast(exp.Binary, l)
 520            a, b = l.left, l.right
 521
 522        if not a_predicate(a) and b_predicate(b):
 523            pass
 524        elif not a_predicate(b) and b_predicate(a):
 525            a, b = b, a
 526        else:
 527            return expression
 528
 529        return expression.__class__(
 530            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 531        )
 532    return expression
 533
 534
 535def simplify_literals(expression, root=True):
 536    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 537        return _flat_simplify(expression, _simplify_binary, root)
 538
 539    if isinstance(expression, exp.Neg):
 540        this = expression.this
 541        if this.is_number:
 542            value = this.name
 543            if value[0] == "-":
 544                return exp.Literal.number(value[1:])
 545            return exp.Literal.number(f"-{value}")
 546
 547    if type(expression) in INVERSE_DATE_OPS:
 548        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 549
 550    return expression
 551
 552
# Binary expressions that `_simplify_binary` leaves alone instead of folding to NULL
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 554
 555
 556def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 557    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 558        this = _simplify_integer_cast(expr.this)
 559    else:
 560        this = expr.this
 561
 562    if isinstance(expr, exp.Cast) and this.is_int:
 563        num = int(this.name)
 564
 565        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 566        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 567        # engine-dependent
 568        if (
 569            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 570        ) or (
 571            UTINYINT_MIN <= num <= UTINYINT_MAX
 572            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 573        ):
 574            return this
 575
 576    return expr
 577
 578
 579def _simplify_binary(expression, a, b):
 580    if isinstance(expression, COMPARISONS):
 581        a = _simplify_integer_cast(a)
 582        b = _simplify_integer_cast(b)
 583
 584    if isinstance(expression, exp.Is):
 585        if isinstance(b, exp.Not):
 586            c = b.this
 587            not_ = True
 588        else:
 589            c = b
 590            not_ = False
 591
 592        if is_null(c):
 593            if isinstance(a, exp.Literal):
 594                return exp.true() if not_ else exp.false()
 595            if is_null(a):
 596                return exp.false() if not_ else exp.true()
 597    elif isinstance(expression, NULL_OK):
 598        return None
 599    elif is_null(a) or is_null(b):
 600        return exp.null()
 601
 602    if a.is_number and b.is_number:
 603        num_a = int(a.name) if a.is_int else Decimal(a.name)
 604        num_b = int(b.name) if b.is_int else Decimal(b.name)
 605
 606        if isinstance(expression, exp.Add):
 607            return exp.Literal.number(num_a + num_b)
 608        if isinstance(expression, exp.Mul):
 609            return exp.Literal.number(num_a * num_b)
 610
 611        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 612        if isinstance(expression, exp.Sub):
 613            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 614        if isinstance(expression, exp.Div):
 615            # engines have differing int div behavior so intdiv is not safe
 616            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 617                return None
 618            return exp.Literal.number(num_a / num_b)
 619
 620        boolean = eval_boolean(expression, num_a, num_b)
 621
 622        if boolean:
 623            return boolean
 624    elif a.is_string and b.is_string:
 625        boolean = eval_boolean(expression, a.this, b.this)
 626
 627        if boolean:
 628            return boolean
 629    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 630        a, b = extract_date(a), extract_interval(b)
 631        if a and b:
 632            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 633                return date_literal(a + b)
 634            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 635                return date_literal(a - b)
 636    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 637        a, b = extract_interval(a), extract_date(b)
 638        # you cannot subtract a date from an interval
 639        if a and b and isinstance(expression, exp.Add):
 640            return date_literal(a + b)
 641    elif _is_date_literal(a) and _is_date_literal(b):
 642        if isinstance(expression, exp.Predicate):
 643            a, b = extract_date(a), extract_date(b)
 644            boolean = eval_boolean(expression, a, b)
 645            if boolean:
 646                return boolean
 647
 648    return None
 649
 650
def simplify_parens(expression):
    """Remove a pair of parentheses when one of the conditions below holds.

    Returns the wrapped expression when the parentheses are dropped, otherwise
    `expression` unchanged. Parenthesized SELECTs are never unwrapped.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        # the parent is neither a condition nor a binary expression
        not isinstance(parent, (exp.Condition, exp.Binary))
        # nested parentheses
        or isinstance(parent, exp.Paren)
        # non-binary content, except NOT / IS directly under a predicate
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        # a predicate inside a non-predicate parent
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        # (a + b) + c, (a * b) * c and (a * b) +/- c
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
 673
 674
 675def _is_nonnull_constant(expression: exp.Expression) -> bool:
 676    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 677
 678
 679def _is_constant(expression: exp.Expression) -> bool:
 680    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 681
 682
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    - A COALESCE with a single argument, or whose first argument is a non-null
      constant, is replaced by that first argument.
    - `COALESCE(x, ..., <constant>, ...) <cmp> <constant>` is expanded into
      `(x IS NOT NULL AND <comparison>) OR (x IS NULL AND <constant> <cmp> <constant>)`,
      dropping the arguments from the first constant onwards. This mutates the
      COALESCE node in place.

    Returns the replacement, or `expression` when nothing applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE side of the comparison; the other side is `other`
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the arguments in front of that constant
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 739
 740
# Expression classes handled by `simplify_concat`
CONCATS = (exp.Concat, exp.DPipe)
 742
 743
 744def simplify_concat(expression):
 745    """Reduces all groups that contain string literals by concatenating them."""
 746    if not isinstance(expression, CONCATS) or (
 747        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 748        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 749    ):
 750        return expression
 751
 752    if isinstance(expression, exp.ConcatWs):
 753        sep_expr, *expressions = expression.expressions
 754        sep = sep_expr.name
 755        concat_type = exp.ConcatWs
 756        args = {}
 757    else:
 758        expressions = expression.expressions
 759        sep = ""
 760        concat_type = exp.Concat
 761        args = {
 762            "safe": expression.args.get("safe"),
 763            "coalesce": expression.args.get("coalesce"),
 764        }
 765
 766    new_args = []
 767    for is_string_group, group in itertools.groupby(
 768        expressions or expression.flatten(), lambda e: e.is_string
 769    ):
 770        if is_string_group:
 771            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 772        else:
 773            new_args.extend(group)
 774
 775    if len(new_args) == 1 and new_args[0].is_string:
 776        return new_args[0]
 777
 778    if concat_type is exp.ConcatWs:
 779        new_args = [sep_expr] + new_args
 780
 781    return concat_type(expressions=new_args, **args)
 782
 783
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, a branch whose condition is always true replaces the whole
    expression with its result, and branches whose condition is always false are
    removed; when no branch is left, the default (or NULL) is returned. For a
    standalone IF, the matching result (or NULL for a missing false branch) is returned.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # Drop the dead branch from the CASE
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
 808
 809
 810def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 811    """
 812    Reduces a prefix check to either TRUE or FALSE if both the string and the
 813    prefix are statically known.
 814
 815    Example:
 816        >>> from sqlglot import parse_one
 817        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 818        'TRUE'
 819    """
 820    if (
 821        isinstance(expression, exp.StartsWith)
 822        and expression.this.is_string
 823        and expression.expression.is_string
 824    ):
 825        return exp.convert(expression.name.startswith(expression.expression.name))
 826
 827    return expression
 828
 829
# A (start, end) pair of dates; `_datetrunc_range` documents it as the half-open range [min, max)
DateRange = t.Tuple[datetime.date, datetime.date]
 831
 832
 833def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 834    """
 835    Get the date range for a DATE_TRUNC equality comparison:
 836
 837    Example:
 838        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 839    Returns:
 840        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 841    """
 842    floor = date_floor(date, unit, dialect)
 843
 844    if date != floor:
 845        # This will always be False, except for NULL values.
 846        return None
 847
 848    return floor, floor + interval(unit)
 849
 850
 851def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 852    """Get the logical expression for a date range"""
 853    return exp.and_(
 854        left >= date_literal(drange[0]),
 855        left < date_literal(drange[1]),
 856        copy=False,
 857    )
 858
 859
 860def _datetrunc_eq(
 861    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 862) -> t.Optional[exp.Expression]:
 863    drange = _datetrunc_range(date, unit, dialect)
 864    if not drange:
 865        return None
 866
 867    return _datetrunc_eq_expression(left, drange)
 868
 869
 870def _datetrunc_neq(
 871    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 872) -> t.Optional[exp.Expression]:
 873    drange = _datetrunc_range(date, unit, dialect)
 874    if not drange:
 875        return None
 876
 877    return exp.and_(
 878        left < date_literal(drange[0]),
 879        left >= date_literal(drange[1]),
 880        copy=False,
 881    )
 882
 883
 884DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 885    exp.LT: lambda l, dt, u, d: l
 886    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
 887    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
 888    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
 889    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
 890    exp.EQ: _datetrunc_eq,
 891    exp.NEQ: _datetrunc_neq,
 892}
 893DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 894DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 895
 896
 897def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 898    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 899
 900
 901@catch(ModuleNotFoundError, UnsupportedUnit)
 902def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
 903    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 904    comparison = expression.__class__
 905
 906    if isinstance(expression, DATETRUNCS):
 907        date = extract_date(expression.this)
 908        if date and expression.unit:
 909            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
 910    elif comparison not in DATETRUNC_COMPARISONS:
 911        return expression
 912
 913    if isinstance(expression, exp.Binary):
 914        l, r = expression.left, expression.right
 915
 916        if not _is_datetrunc_predicate(l, r):
 917            return expression
 918
 919        l = t.cast(exp.DateTrunc, l)
 920        unit = l.unit.name.lower()
 921        date = extract_date(r)
 922
 923        if not date:
 924            return expression
 925
 926        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
 927    elif isinstance(expression, exp.In):
 928        l = expression.this
 929        rs = expression.expressions
 930
 931        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 932            l = t.cast(exp.DateTrunc, l)
 933            unit = l.unit.name.lower()
 934
 935            ranges = []
 936            for r in rs:
 937                date = extract_date(r)
 938                if not date:
 939                    return expression
 940                drange = _datetrunc_range(date, unit, dialect)
 941                if drange:
 942                    ranges.append(drange)
 943
 944            if not ranges:
 945                return expression
 946
 947            ranges = merge_ranges(ranges)
 948
 949            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
 950
 951    return expression
 952
 953
 954def sort_comparison(expression: exp.Expression) -> exp.Expression:
 955    if expression.__class__ in COMPLEMENT_COMPARISONS:
 956        l, r = expression.this, expression.expression
 957        l_column = isinstance(l, exp.Column)
 958        r_column = isinstance(r, exp.Column)
 959        l_const = _is_constant(l)
 960        r_const = _is_constant(r)
 961
 962        if (l_column and not r_column) or (r_const and not l_const):
 963            return expression
 964        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
 965            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
 966                this=r, expression=l
 967            )
 968    return expression
 969
 970
# (side, kind) pairs of joins that `remove_where_true` may turn into a CROSS join
# when their ON condition is always true.
# A CROSS join results in an empty table if the right table is empty, so only
# certain types of joins can be simplified to CROSS.
# In other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
 980
 981
 982def remove_where_true(expression):
 983    for where in expression.find_all(exp.Where):
 984        if always_true(where.this):
 985            where.pop()
 986    for join in expression.find_all(exp.Join):
 987        if (
 988            always_true(join.args.get("on"))
 989            and not join.args.get("using")
 990            and not join.args.get("method")
 991            and (join.side, join.kind) in JOINS
 992        ):
 993            join.args["on"].pop()
 994            join.set("side", None)
 995            join.set("kind", "CROSS")
 996
 997
 998def always_true(expression):
 999    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
1000        expression, exp.Literal
1001    )
1002
1003
def always_false(expression):
    """Return whether `expression` is the FALSE or the NULL constant."""
    if is_false(expression):
        return True
    return is_null(expression)
1006
1007
def is_complement(a, b):
    """Return whether `b` is the negation of `a`, i.e. `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1010
1011
def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Boolean node (not a subclass) holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1014
1015
def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
1018
1019
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns:
        A boolean literal expression with the result, or None when
        `expression` is not one of the supported comparison types.
    """
    # Checked in order; the comparison itself only runs for the matching entry.
    comparators = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for kinds, compare in comparators:
        if isinstance(expression, kinds):
            return boolean_literal(compare())

    return None
1034
1035
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Datetimes are reduced to their date part, dates are returned as-is, and
    anything else is parsed as an ISO-8601 string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime is a subclass of date, so it is singled out here
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
1045
1046
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Datetimes are returned as-is, dates become midnight of that day, and
    anything else is parsed as an ISO-8601 string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        if isinstance(value, datetime.datetime):
            return value
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None

    return parsed
1056
1057
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to the Python date type that matches the data type `to`.

    Args:
        value: the value to convert; falsy values yield None.
        to: the target data type.

    Returns:
        A `datetime.date` for DATE, a `datetime.datetime` for the other
        temporal types, and None for a falsy value, an unparseable value or a
        non-temporal target type.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1066
1067
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Evaluate a constant date expression such as `CAST('2021-01-01' AS DATE)`.

    Supported shapes are a Cast, or a TsOrDsToDate without a format, applied
    to a literal or to another such (nested) expression.

    Returns:
        The date or datetime value, or None when `cast` does not have a
        supported shape or its value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are evaluated inside-out.
        value = extract_date(operand)
    else:
        return None

    return cast_value(value, to)
1083
1084
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return whether `extract_date` can evaluate `expression` to a date value."""
    extracted = extract_date(expression)
    return extracted is not None
1087
1088
def extract_interval(expression):
    """
    Convert `expression` (an integer amount plus a `unit` arg) into the value
    built by `interval`.

    Returns None when the amount is not an integer, the unit is unsupported,
    or the module `interval` imports is missing.
    """
    try:
        amount = int(expression.name)
        unit_name = expression.text("unit").lower()
        return interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1096
1097
def date_literal(date):
    """
    Build a cast of the string form of `date` to DATETIME when `date` is a
    `datetime.datetime`, and to DATE otherwise.
    """
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
1107
1108
def interval(unit: str, n: int = 1):
    """
    Build a `relativedelta` that spans `n` times the given unit.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spans[unit]
    return relativedelta(**{keyword: multiplier * n})
1130
1131
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate `d` down to the start of its year, quarter, month, week or day.

    Args:
        d: the date or datetime to truncate.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only its `WEEK_OFFSET` is read, to shift the start of the week.

    Returns:
        A value of the same type as `d`. For datetimes the time of day is
        reset to midnight, since truncating to a day or coarser unit discards
        it; otherwise a datetime such as `2021-01-01 12:00:00` would be
        reported as its own floor.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "year":
        floor = d.replace(month=1, day=1)
    elif unit == "quarter":
        # first month of the quarter: 1, 4, 7 or 10
        floor = d.replace(month=d.month - (d.month - 1) % 3, day=1)
    elif unit == "month":
        floor = d.replace(day=1)
    elif unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        floor = d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
    elif unit == "day":
        floor = d
    else:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    if isinstance(floor, datetime.datetime):
        floor = floor.replace(hour=0, minute=0, second=0, microsecond=0)

    return floor
1153
1154
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the next boundary of `unit`.

    A value that already equals its own floor is returned unchanged.
    """
    floor = date_floor(d, unit, dialect)
    return d if floor == d else floor + interval(unit)
1162
1163
def boolean_literal(condition):
    """Return the TRUE expression when `condition` is truthy and FALSE otherwise."""
    if condition:
        return exp.true()
    return exp.false()
1166
1167
def _flat_simplify(expression, simplifier, root=True):
    """
    Simplify the flattened operands of `expression` pairwise.

    `simplifier(expression, a, b)` is called for pairs of operands; a truthy
    result other than `expression` itself replaces both operands of the pair.

    Returns:
        A chain of `expression.__class__` nodes over the remaining operands
        when at least one pair was combined, otherwise `expression` itself.
    """
    # Nested nodes that share their parent's type are skipped unless this is the root.
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # `a` and `b` collapse into `result`, which is put back at the
                    # front so it gets compared with the remaining operands next.
                    # Breaking right away makes mutating the queue here safe.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # `a` could not be combined with any remaining operand: keep it.
                operands.append(a)

        # Only rebuild the expression if something was actually combined.
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
1192
1193
def gen(expression: t.Any) -> str:
    """Produce a cheap pseudo-SQL string for `expression`, suitable for sorting and deduping.

    Optimization needs to sort and dedupe SQL, but running the real generator
    for that is expensive, so a bare minimum generator (`Gen`) is used instead.
    """
    generator = Gen()
    return generator.gen(expression)
1201
1202
class Gen:
    """
    Minimal stack-based pseudo-SQL generator used by `gen`.

    Nodes are pushed onto `self.stack` and popped one at a time. Handlers push
    the pieces of a node in reverse output order, so that they are popped in
    the intended order. Anything popped that is neither an expression nor a
    list is stringified and appended to `self.sqls`.
    """

    def __init__(self):
        # work list of expressions, lists and string fragments still to be emitted
        self.stack = []
        # output fragments, joined at the end of `gen`
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo-SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # dispatch: dedicated `<key>_sql` handler, then generic function, then fallback
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # `_args` pushes the node's args first, so the key pushed here is popped before them
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push the items in reverse, each followed by a separator, then drop the
                # separator left on top, so the items are popped in order as `a,b,c`.
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                if node:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # emitted as `<this> AS <alias>` (pushed in reverse)
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emit `NAME(args)`; quoted identifier names keep their case, other names are upper-cased."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        # emitted as `<this> BETWEEN <low> AND <high>`
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        # emitted as `<this>[<expressions>]`
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # dot-separated parts; the separator left on top after the loop is dropped
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # type name first, followed by any set args after the first declared arg type
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # emitted as `<this> IN (<remaining args>)`
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case " Like " differs from the other operators; it is part
        # of the generated strings, so it is left as is.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # emitted as `(<this>)`, then the alias if set, then any other set args
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # emitted as the dot-separated parts, then the alias if set, then any other set args
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        # emitted as ` AS <this>`, followed by `(<columns>)` when columns are present
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `<this><op><expression>` (in reverse, so `this` is popped first)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `<op><this>` (in reverse, so `op` is popped first)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `<sql_name>(<all arg values>)` for a function without a dedicated handler."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the args of `node` that are not None as a list of `[":<name>", value]` pairs.

        Args:
            node: the expression whose args are pushed.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if anything was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        # NOTE(review): `arg_types or arg_types` is redundant; it evaluates to `arg_types`.
        for k in arg_types or arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
# Final means that an expression should not be simplified
FINAL = 'final'
# Value ranges for byte-sized signed/unsigned integers
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised by `interval` and `date_floor` for a date/time unit they do not handle."""

    pass

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None):
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: the dialect, resolved through `Dialect.get_or_raise` and handed
            to the date truncation rule

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes marked as final are left exactly as they are.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # A projection containing a group by key is frozen as a whole.
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause.
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Simplify the children (as non-root nodes) before the post-order rules run.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The node may be a new one by now: give it the original parent before the
        # remaining rules run.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Apply the rules repeatedly through `while_changing`.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the decorated function raises one of `exceptions`, the expression it
    was called with is returned unchanged instead. The wrapper keeps the
    decorated function's name, docstring and signature metadata, so decorated
    rules stay identifiable in tracebacks and generated documentation.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification is only done on lt/lte/gt/gte, so BETWEEN is
    lowered to that form first. If the BETWEEN is the direct operand of a NOT,
    the conjunction is wrapped in parentheses.
    """
    if not isinstance(expression, exp.Between):
        return expression

    negated = isinstance(expression.parent, exp.Not)

    lower = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower, upper, copy=False)

    return exp.paren(conjunction, copy=False) if negated else conjunction

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    NOT NULL becomes NULL, negated comparisons become their complement,
    negated constants are folded and a double negation is removed.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        complement = COMPLEMENT_COMPARISONS[operand_type]
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: negate both sides and swap the connector
            dual = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
def flatten(expression):
    """
    Remove the parentheses around an operand that is the same kind of
    connector as its parent:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector):
            operand.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Simplify the operands of an AND / OR pairwise: duplicates, constants
    (TRUE, FALSE, NULL) and comparisons.

    Expressions that are not connectors are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Pairwise rule for `_flat_simplify`. A falsy result (including the implicit
        # None for a connector that is neither AND nor OR) keeps the pair unchanged.
        if left == right:
            return left
        if isinstance(expression, exp.And):
            # FALSE is checked before NULL, so FALSE AND NULL -> FALSE
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # an always-true operand can be dropped from a conjunction
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            # NULL combined with NULL or FALSE stays NULL
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            # a FALSE operand can be dropped from a disjunction
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not (root or not expression.same_parent):
        return expression

    collapsed = exp.false() if isinstance(expression, exp.And) else exp.true()

    # every ordered pair is checked, since the negation may come before or after its operand
    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return collapsed

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Dedupe the operands of a connector and put them in sorted order.

    C AND A AND B AND B -> A AND B AND C
    """
    if not (isinstance(expression, exp.Connector) and (root or not expression.same_parent)):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # operands with the same generated string are considered duplicates
    unique = {gen(operand): operand for operand in operands}
    keyed = tuple(unique.items())

    # check if the operands are already sorted, if not sort them
    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))
    if out_of_order:
        return combine(*(operand for _, operand in sorted(keyed)), copy=False)

    # we didn't have to sort but maybe we need to dedup
    if len(unique) < len(operands):
        return combine(*unique.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place (replaced by constants or by the
    surviving operand) and `expression` itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands to look into
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # `aa` is NOT b: it is replaced by the neutral constant of `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # the operands of `b` are a proper subset of those of `a`: `a` is
                    # replaced by the neutral constant of the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # both sides share one operand and differ only by a negated one
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Columns are replaced in place with a copy of the literal they are equated to;
    the column inside the defining equality itself and columns directly under an
    `IS NULL` check are left untouched.

    Reference: https://www.sqlite.org/optoverview.html
    """
    eligible = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not eligible:
        return expression

    # column -> (id of the column node in its defining equality, literal value)
    constant_mapping = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            constant_mapping[lhs] = (id(lhs), rhs)

    if not constant_mapping:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        entry = constant_mapping.get(column)
        if entry is None:
            continue

        defining_id, constant = entry
        parent = column.parent
        under_is_null = isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)

        # Skip the very node that defined the mapping, and `col IS NULL` checks
        if id(column) != defining_id and not under_is_null:
            column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
125        def wrapped(expression, *args, **kwargs):
126            try:
127                return func(expression, *args, **kwargs)
128            except exceptions:
129                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands of non-connector binary expressions, negated numbers
    and inverse date operations. Returns the input unchanged when nothing applies.
    """
    is_binary = isinstance(expression, exp.Binary)
    if is_binary and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating an already-negative number drops the sign, otherwise add one
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Strip a pair of parentheses when they are not needed to preserve precedence.

    Returns the wrapped expression when the parentheses are redundant, otherwise
    the original `Paren` node (or the input itself if it is not a `Paren`).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # Parenthesized subqueries are always kept
    if isinstance(inner, exp.Select):
        return expression

    outer_is_predicate = isinstance(outer, exp.Predicate)
    inner_is_binary = isinstance(inner, exp.Binary)
    inner_is_mul = isinstance(inner, exp.Mul)

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or (
            not inner_is_binary
            and not (isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate)
        )
        or (isinstance(inner, exp.Predicate) and not outer_is_predicate)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (inner_is_mul and isinstance(outer, exp.Mul))
        or (inner_is_mul and isinstance(outer, (exp.Add, exp.Sub)))
    )

    return inner if redundant else expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    `COALESCE(x)` (or a COALESCE whose first argument is a non-null constant)
    collapses to its first argument. A comparison of a COALESCE against a constant
    is split at the first constant argument into
    `(<prefix> IS NOT NULL AND <comparison>) OR (<prefix> IS NULL AND <const> <op> <other>)`.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    first_constant = next(
        ((index, arg) for index, arg in enumerate(coalesce.expressions) if _is_constant(arg)),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    # Truncate the COALESCE in place so that only the args before the constant remain
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    non_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(non_null_branch, null_branch, copy=False))
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Consecutive string-literal operands are merged into one literal (joined with
    the separator for CONCAT_WS). If a single string literal remains it is
    returned directly; otherwise a new concat node of the same family is built.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        concat_type = exp.Concat
        extra_args = {key: expression.args.get(key) for key in ("safe", "coalesce")}

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            merged.append(exp.Literal.string(sep.join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if concat_type is exp.ConcatWs:
        merged = [sep_expr] + merged

    return concat_type(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches are visited in order:
      * a statically false/NULL branch is removed;
      * a statically true branch yields its result, but only if no earlier
        branch with an undetermined condition was kept, since CASE returns the
        result of the *first* matching WHEN;
      * if every branch was removed, the ELSE value (or NULL) is returned.

    For a standalone IF, the true or false operand is returned when the
    condition is statically known.

    Returns:
        The simplified expression, or the (possibly mutated) input.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        undetermined_branch_kept = False

        # Iterate over a snapshot: branches are popped from the live list below
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undetermined_branch_kept:
                    return case.args["true"]
                # An earlier branch may still match at runtime, so this one must stay
            elif always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undetermined_branch_kept = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a STARTSWITH call into a boolean constant when both of its operands
    are string literals; any other expression is returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if not expression.this.is_string:
        return expression

    prefix = expression.expression
    if not prefix.is_string:
        return expression

    has_prefix = expression.name.startswith(prefix.name)
    return exp.convert(has_prefix)

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>}
def simplify_datetrunc(expression, *args, **kwargs):
125        def wrapped(expression, *args, **kwargs):
126            try:
127                return func(expression, *args, **kwargs)
128            except exceptions:
129                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    A comparison is left alone when its left side is the only column or its right
    side is the only constant. It is flipped (using the inverse operator where one
    is registered) when the right side is the only column, the left side is the
    only constant, or the left operand's `gen` string sorts after the right one's.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    # Already in canonical order
    if (lhs_is_column and not rhs_is_column) or (rhs_is_const and not lhs_is_const):
        return expression

    must_flip = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or gen(lhs) > gen(rhs)
    )
    if not must_flip:
        return expression

    flipped_type = INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return flipped_type(this=rhs, expression=lhs)
JOINS = {('', ''), ('RIGHT', ''), ('RIGHT', 'OUTER'), ('', 'INNER')}
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Drop WHERE clauses whose condition is always true and turn joins with an
    always-true ON condition into CROSS joins. The expression is mutated in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        # USING / join-method joins are left alone
        if join.args.get("using") or join.args.get("method"):
            continue
        # Only the (side, kind) combinations listed in JOINS are converted
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
 999def always_true(expression):
1000    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
1001        expression, exp.Literal
1002    )
def always_false(expression):
def always_false(expression):
    """Return True if the expression is the FALSE boolean or NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Check whether `b` is `NOT a`, i.e. a `Not` node wrapping something equal to `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True only for an exact `Boolean` node whose value is falsy."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True only when `a` is exactly an `exp.Null` node (subclasses do not count)."""
    return type(a) is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node against two already-extracted Python values.

    Returns a TRUE/FALSE boolean literal for =, IS, <>, >, >=, < and <=,
    or None when `expression` is not one of those comparisons.
    """
    # Checked in order; each comparison is only evaluated for the matching node type
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Coerce a value to a `datetime.date`.

    Datetimes are truncated to their date, dates are returned as is, and anything
    else is parsed as an ISO-format string. Returns None if parsing fails with a
    ValueError.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce a value to a `datetime.datetime`.

    Datetimes are returned as is, dates become midnight of that day, and anything
    else is parsed as an ISO-format string. Returns None if parsing fails with a
    ValueError.
    """
    if isinstance(value, datetime.date):
        if isinstance(value, datetime.datetime):
            return value
        # Plain date: promote to midnight of the same day
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None

    return parsed
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert a raw value to the Python date/datetime matching a target SQL type.

    Args:
        value: the value to convert; falsy values short-circuit to None.
        to: the target data type.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` when `to` is any
        other temporal type, and None for falsy input, unparsable values or
        non-temporal target types.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Extract a Python date/datetime from a CAST (or format-less TS_OR_DS_TO_DATE)
    wrapped around a literal, following nested casts recursively.

    Args:
        cast: the expression to inspect.

    Returns:
        The converted value, or None if `cast` is not a supported cast, its operand
        is neither a literal nor another supported cast, or conversion fails.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format the function behaves like a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: resolve the inner value first, then cast it to the outer type
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """
    Build the time delta described by an interval expression from its integer
    value and its lower-cased `unit` text.

    Returns None when the value is not an integer, the unit is unsupported, or
    the module needed by `interval` cannot be imported.
    """
    try:
        amount = int(expression.name)
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def date_literal(date):
def date_literal(date):
    """
    Wrap a Python date/datetime in `CAST('<value>' AS DATE|DATETIME)`.

    DATETIME is used for `datetime.datetime` instances, DATE for everything else.
    """
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """
    Return a `relativedelta` of `n` units.

    Supported units: year, quarter (three months), month, week, day, hour,
    minute and second.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    unit_specs = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    spec = unit_specs.get(unit)
    if spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spec
    return relativedelta(**{keyword: factor * n})
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round a date down to the first day of the period it falls in.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`, the first day of the week expressed
            relative to Monday (per the sqlglot Dialect docs: 0 is Monday,
            -1 is Sunday).

    Returns:
        The start of the period containing `d`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # date.weekday() numbers Monday as 0 and Sunday as 6. The modulo keeps the
        # step back within 0..6 days for any offset: without it a Sunday-start week
        # (offset -1) floors a Sunday a full week back, and a positive offset can
        # yield a negative step, i.e. a "floor" that lies after `d`.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round a date up to the next period boundary of the given unit.

    A date that already sits on a boundary (its floor equals itself) is returned
    unchanged; otherwise the result is its floor plus one unit.
    """
    period_start = date_floor(d, unit, dialect)
    return d if period_start == d else period_start + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Map a Python truth value to the corresponding TRUE/FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Render an expression as a cheap pseudo-SQL string for sorting and deduplication.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
class Gen:
    """Minimal stack-based pseudo-SQL generator used by `gen`.

    `gen` repeatedly pops items off `self.stack`. Expression nodes are expanded by
    a `<key>_sql` handler (or a generic fallback) that pushes their pieces back onto
    the stack; every other non-None item is stringified and appended to `self.sqls`.
    Because the stack is LIFO, handlers push pieces in the reverse of the order in
    which they should appear in the output.
    """

    def __init__(self):
        # Pending work items: expressions, lists of items, or plain strings
        self.stack = []
        # Output fragments, joined into the final string by `gen`
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` to a string, resetting any state from a previous call."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # Dispatch on the node key, e.g. an `add` node is handled by `add_sql`
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by the node's args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push the items so they pop in list order, each preceded by a ","
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                # Drop the "," that would otherwise precede the first item.
                # NOTE(review): for a non-empty list made up only of None values nothing
                # was pushed above, so this pops an unrelated item - confirm whether
                # such lists can reach this point (e.g. via `_function`).
                if node:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        """Render as `<this> + <expression>`."""
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        """Render as `<this> AS <alias>`."""
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        """Render as `<this> AND <expression>`."""
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Render as `NAME(<expressions>)`.

        The name is upper-cased unless it comes from a quoted identifier, in which
        case it is kept as is and wrapped in double quotes.

        Raises:
            ValueError: if `e.this` is neither a str nor an Identifier.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        """Render as `<this> BETWEEN <low> AND <high>`."""
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        """Render as `TRUE` or `FALSE`."""
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        """Render as `<this>[<expressions>]`."""
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        """Render the column parts joined by `.`."""
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." that would otherwise precede the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        """Render the type name followed by a space and its args after the first arg type."""
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        """Render as `<this> / <expression>`."""
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        """Render as `<this>.<expression>`."""
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        """Render as `<this> = <expression>`."""
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        """Render as `FROM <this>`."""
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        """Render as `<this> > <expression>`."""
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        """Render as `<this> >= <expression>`."""
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        """Render the identifier name, wrapped in double quotes if it is quoted."""
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        """Render as `<this> ILIKE <expression>`."""
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        """Render as `<this> IN (<args after the first arg type>)`."""
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        """Render as `<this> DIV <expression>`."""
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        """Render as `<this> IS <expression>`."""
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        """Render as `<this> Like <expression>` (note the mixed-case operator text)."""
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        """Render string literals in single quotes and other literals verbatim."""
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        """Render as `<this> < <expression>`."""
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        """Render as `<this> <= <expression>`."""
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        """Render as `<this> % <expression>`."""
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        """Render as `<this> * <expression>`."""
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        """Render as `-<this>`."""
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        """Render as `<this> <> <expression>`."""
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        """Render as `NOT <this>`."""
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        """Render as `NULL`."""
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        """Render as `<this> OR <expression>`."""
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        """Render as `(<this>)`."""
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        """Render as `<this> - <expression>`."""
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        """Render as `(<this>)`, then the alias if any, then args after the first two arg types."""
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        """Render the dotted table parts, then the alias if any, then args after the first four arg types."""
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." that would otherwise precede the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        """Render as ` AS <this>`, followed by `(<columns>)` when columns are present."""
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        """Render the variable text verbatim."""
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push a binary node so that it renders as `<this><op><expression>`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push a unary node so that it renders as `<op><this>`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push a function node so that it renders as `<sql_name>(<arg values>)`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the node's non-None args as a list of `[":<key>", value]` pairs.

        Args:
            node: the expression whose args are pushed.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one pair was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        # NOTE(review): both operands of the `or` are the same name, so it is a no-op
        for k in arg_types or arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: sqlglot.expressions.Expression) -> str:
    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` to a string, resetting any state from a previous call."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # Dispatch on the node key, e.g. an `add` node is handled by `add_sql`
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by the node's args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push the items so they pop in list order, each preceded by a ","
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                # Drop the "," that would otherwise precede the first item.
                # NOTE(review): for a non-empty list made up only of None values nothing
                # was pushed above, so this pops an unrelated item - confirm whether
                # such lists can reach this point.
                if node:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.Add) -> None:
    def add_sql(self, e: exp.Add) -> None:
        """Render as `<this> + <expression>`."""
        self._binary(e, " + ")
def alias_sql(self, e: sqlglot.expressions.Alias) -> None:
    def alias_sql(self, e: exp.Alias) -> None:
        """Render as `<this> AS <alias>`."""
        # Pushed in reverse because the stack pops last-in first
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )
def and_sql(self, e: sqlglot.expressions.And) -> None:
    def and_sql(self, e: exp.And) -> None:
        """Render as `<this> AND <expression>`."""
        self._binary(e, " AND ")
def anonymous_sql(self, e: sqlglot.expressions.Anonymous) -> None:
    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Render as `NAME(<expressions>)`.

        The name is upper-cased unless it comes from a quoted identifier, in which
        case it is kept as is and wrapped in double quotes.

        Raises:
            ValueError: if `e.this` is neither a str nor an Identifier.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        # Pushed in reverse because the stack pops last-in first
        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )
def between_sql(self, e: sqlglot.expressions.Between) -> None:
    def between_sql(self, e: exp.Between) -> None:
        """Render as `<this> BETWEEN <low> AND <high>`."""
        # Pushed in reverse because the stack pops last-in first
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )
def boolean_sql(self, e: sqlglot.expressions.Boolean) -> None:
    def boolean_sql(self, e: exp.Boolean) -> None:
        """Render as `TRUE` or `FALSE`."""
        self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: sqlglot.expressions.Bracket) -> None:
    def bracket_sql(self, e: exp.Bracket) -> None:
        """Render as `<this>[<expressions>]`."""
        # Pushed in reverse because the stack pops last-in first
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )
def column_sql(self, e: sqlglot.expressions.Column) -> None:
    def column_sql(self, e: exp.Column) -> None:
        """Render the column parts joined by `.`."""
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." that would otherwise precede the first part
        self.stack.pop()
def datatype_sql(self, e: sqlglot.expressions.DataType) -> None:
    def datatype_sql(self, e: exp.DataType) -> None:
        """Render the type name followed by a space and its args after the first arg type."""
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")
def div_sql(self, e: sqlglot.expressions.Div) -> None:
    def div_sql(self, e: exp.Div) -> None:
        """Render as `<this> / <expression>`."""
        self._binary(e, " / ")
def dot_sql(self, e: sqlglot.expressions.Dot) -> None:
    def dot_sql(self, e: exp.Dot) -> None:
        """Render as `<this>.<expression>`."""
        self._binary(e, ".")
def eq_sql(self, e: sqlglot.expressions.EQ) -> None:
    def eq_sql(self, e: exp.EQ) -> None:
        """Render as `<this> = <expression>`."""
        self._binary(e, " = ")
def from_sql(self, e: sqlglot.expressions.From) -> None:
    def from_sql(self, e: exp.From) -> None:
        """Render as `FROM <this>`."""
        self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: sqlglot.expressions.GT) -> None:
    def gt_sql(self, e: exp.GT) -> None:
        """Render as `<this> > <expression>`."""
        self._binary(e, " > ")
def gte_sql(self, e: sqlglot.expressions.GTE) -> None:
    def gte_sql(self, e: exp.GTE) -> None:
        """Render as `<this> >= <expression>`."""
        self._binary(e, " >= ")
def identifier_sql(self, e: sqlglot.expressions.Identifier) -> None:
    def identifier_sql(self, e: exp.Identifier) -> None:
        """Render the identifier name, wrapped in double quotes if it is quoted."""
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: sqlglot.expressions.ILike) -> None:
    def ilike_sql(self, e: exp.ILike) -> None:
        """Render as `<this> ILIKE <expression>`."""
        self._binary(e, " ILIKE ")
def in_sql(self, e: sqlglot.expressions.In) -> None:
1331    def in_sql(self, e: exp.In) -> None:
1332        self.stack.append(")")
1333        self._args(e, 1)
1334        self.stack.extend(
1335            (
1336                "(",
1337                " IN ",
1338                e.this,
1339            )
1340        )
def intdiv_sql(self, e: sqlglot.expressions.IntDiv) -> None:
1342    def intdiv_sql(self, e: exp.IntDiv) -> None:
1343        self._binary(e, " DIV ")
def is_sql(self, e: sqlglot.expressions.Is) -> None:
1345    def is_sql(self, e: exp.Is) -> None:
1346        self._binary(e, " IS ")
def like_sql(self, e: sqlglot.expressions.Like) -> None:
1348    def like_sql(self, e: exp.Like) -> None:
1349        self._binary(e, " Like ")
def literal_sql(self, e: sqlglot.expressions.Literal) -> None:
1351    def literal_sql(self, e: exp.Literal) -> None:
1352        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: sqlglot.expressions.LT) -> None:
1354    def lt_sql(self, e: exp.LT) -> None:
1355        self._binary(e, " < ")
def lte_sql(self, e: sqlglot.expressions.LTE) -> None:
1357    def lte_sql(self, e: exp.LTE) -> None:
1358        self._binary(e, " <= ")
def mod_sql(self, e: sqlglot.expressions.Mod) -> None:
1360    def mod_sql(self, e: exp.Mod) -> None:
1361        self._binary(e, " % ")
def mul_sql(self, e: sqlglot.expressions.Mul) -> None:
1363    def mul_sql(self, e: exp.Mul) -> None:
1364        self._binary(e, " * ")
def neg_sql(self, e: sqlglot.expressions.Neg) -> None:
1366    def neg_sql(self, e: exp.Neg) -> None:
1367        self._unary(e, "-")
def neq_sql(self, e: sqlglot.expressions.NEQ) -> None:
1369    def neq_sql(self, e: exp.NEQ) -> None:
1370        self._binary(e, " <> ")
def not_sql(self, e: sqlglot.expressions.Not) -> None:
1372    def not_sql(self, e: exp.Not) -> None:
1373        self._unary(e, "NOT ")
def null_sql(self, e: sqlglot.expressions.Null) -> None:
1375    def null_sql(self, e: exp.Null) -> None:
1376        self.stack.append("NULL")
def or_sql(self, e: sqlglot.expressions.Or) -> None:
1378    def or_sql(self, e: exp.Or) -> None:
1379        self._binary(e, " OR ")
def paren_sql(self, e: sqlglot.expressions.Paren) -> None:
1381    def paren_sql(self, e: exp.Paren) -> None:
1382        self.stack.extend(
1383            (
1384                ")",
1385                e.this,
1386                "(",
1387            )
1388        )
def sub_sql(self, e: sqlglot.expressions.Sub) -> None:
1390    def sub_sql(self, e: exp.Sub) -> None:
1391        self._binary(e, " - ")
def subquery_sql(self, e: sqlglot.expressions.Subquery) -> None:
1393    def subquery_sql(self, e: exp.Subquery) -> None:
1394        self._args(e, 2)
1395        alias = e.args.get("alias")
1396        if alias:
1397            self.stack.append(alias)
1398        self.stack.extend((")", e.this, "("))
def table_sql(self, e: sqlglot.expressions.Table) -> None:
1400    def table_sql(self, e: exp.Table) -> None:
1401        self._args(e, 4)
1402        alias = e.args.get("alias")
1403        if alias:
1404            self.stack.append(alias)
1405        for p in reversed(e.parts):
1406            self.stack.extend((p, "."))
1407        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.TableAlias) -> None:
1409    def tablealias_sql(self, e: exp.TableAlias) -> None:
1410        columns = e.columns
1411
1412        if columns:
1413            self.stack.extend((")", columns, "("))
1414
1415        self.stack.extend((e.this, " AS "))
def var_sql(self, e: sqlglot.expressions.Var) -> None:
1417    def var_sql(self, e: exp.Var) -> None:
1418        self.stack.append(e.this)