Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import functools
   5import itertools
   6import typing as t
   7from collections import deque
   8from decimal import Decimal
   9
  10import sqlglot
  11from sqlglot import Dialect, exp
  12from sqlglot.helper import first, merge_ranges, while_changing
  13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot.dialects.dialect import DialectType
  17
  18    DateTruncBinaryTransform = t.Callable[
  19        [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
  20    ]
  21
  22# Final means that an expression should not be simplified
  23FINAL = "final"
  24
  25
  26class UnsupportedUnit(Exception):
  27    pass
  28
  29
  30def simplify(
  31    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
  32):
  33    """
  34    Rewrite sqlglot AST to simplify expressions.
  35
  36    Example:
  37        >>> import sqlglot
  38        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  39        >>> simplify(expression).sql()
  40        'TRUE'
  41
  42    Args:
  43        expression (sqlglot.Expression): expression to simplify
  44        constant_propagation: whether the constant propagation rule should be used
  45
  46    Returns:
  47        sqlglot.Expression: simplified expression
  48    """
  49
  50    dialect = Dialect.get_or_raise(dialect)
  51
  52    def _simplify(expression, root=True):
  53        if expression.meta.get(FINAL):
  54            return expression
  55
  56        # group by expressions cannot be simplified, for example
  57        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
  58        # the projection must exactly match the group by key
  59        group = expression.args.get("group")
  60
  61        if group and hasattr(expression, "selects"):
  62            groups = set(group.expressions)
  63            group.meta[FINAL] = True
  64
  65            for e in expression.selects:
  66                for node in e.walk():
  67                    if node in groups:
  68                        e.meta[FINAL] = True
  69                        break
  70
  71            having = expression.args.get("having")
  72            if having:
  73                for node in having.walk():
  74                    if node in groups:
  75                        having.meta[FINAL] = True
  76                        break
  77
  78        # Pre-order transformations
  79        node = expression
  80        node = rewrite_between(node)
  81        node = uniq_sort(node, root)
  82        node = absorb_and_eliminate(node, root)
  83        node = simplify_concat(node)
  84        node = simplify_conditionals(node)
  85
  86        if constant_propagation:
  87            node = propagate_constants(node, root)
  88
  89        exp.replace_children(node, lambda e: _simplify(e, False))
  90
  91        # Post-order transformations
  92        node = simplify_not(node)
  93        node = flatten(node)
  94        node = simplify_connectors(node, root)
  95        node = remove_complements(node, root)
  96        node = simplify_coalesce(node)
  97        node.parent = expression.parent
  98        node = simplify_literals(node, root)
  99        node = simplify_equality(node)
 100        node = simplify_parens(node)
 101        node = simplify_datetrunc(node, dialect)
 102        node = sort_comparison(node)
 103        node = simplify_startswith(node)
 104
 105        if root:
 106            expression.replace(node)
 107        return node
 108
 109    expression = while_changing(expression, _simplify)
 110    remove_where_true(expression)
 111    return expression
 112
 113
 114def catch(*exceptions):
 115    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 116
 117    def decorator(func):
 118        def wrapped(expression, *args, **kwargs):
 119            try:
 120                return func(expression, *args, **kwargs)
 121            except exceptions:
 122                return expression
 123
 124        return wrapped
 125
 126    return decorator
 127
 128
 129def rewrite_between(expression: exp.Expression) -> exp.Expression:
 130    """Rewrite x between y and z to x >= y AND x <= z.
 131
 132    This is done because comparison simplification is only done on lt/lte/gt/gte.
 133    """
 134    if isinstance(expression, exp.Between):
 135        negate = isinstance(expression.parent, exp.Not)
 136
 137        expression = exp.and_(
 138            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 139            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 140            copy=False,
 141        )
 142
 143        if negate:
 144            expression = exp.paren(expression, copy=False)
 145
 146    return expression
 147
 148
 149COMPLEMENT_COMPARISONS = {
 150    exp.LT: exp.GTE,
 151    exp.GT: exp.LTE,
 152    exp.LTE: exp.GT,
 153    exp.GTE: exp.LT,
 154    exp.EQ: exp.NEQ,
 155    exp.NEQ: exp.EQ,
 156}
 157
 158
 159def simplify_not(expression):
 160    """
 161    Demorgan's Law
 162    NOT (x OR y) -> NOT x AND NOT y
 163    NOT (x AND y) -> NOT x OR NOT y
 164    """
 165    if isinstance(expression, exp.Not):
 166        this = expression.this
 167        if is_null(this):
 168            return exp.null()
 169        if this.__class__ in COMPLEMENT_COMPARISONS:
 170            return COMPLEMENT_COMPARISONS[this.__class__](
 171                this=this.this, expression=this.expression
 172            )
 173        if isinstance(this, exp.Paren):
 174            condition = this.unnest()
 175            if isinstance(condition, exp.And):
 176                return exp.paren(
 177                    exp.or_(
 178                        exp.not_(condition.left, copy=False),
 179                        exp.not_(condition.right, copy=False),
 180                        copy=False,
 181                    )
 182                )
 183            if isinstance(condition, exp.Or):
 184                return exp.paren(
 185                    exp.and_(
 186                        exp.not_(condition.left, copy=False),
 187                        exp.not_(condition.right, copy=False),
 188                        copy=False,
 189                    )
 190                )
 191            if is_null(condition):
 192                return exp.null()
 193        if always_true(this):
 194            return exp.false()
 195        if is_false(this):
 196            return exp.true()
 197        if isinstance(this, exp.Not):
 198            # double negation
 199            # NOT NOT x -> x
 200            return this.this
 201    return expression
 202
 203
 204def flatten(expression):
 205    """
 206    A AND (B AND C) -> A AND B AND C
 207    A OR (B OR C) -> A OR B OR C
 208    """
 209    if isinstance(expression, exp.Connector):
 210        for node in expression.args.values():
 211            child = node.unnest()
 212            if isinstance(child, expression.__class__):
 213                node.replace(child)
 214    return expression
 215
 216
 217def simplify_connectors(expression, root=True):
 218    def _simplify_connectors(expression, left, right):
 219        if left == right:
 220            return left
 221        if isinstance(expression, exp.And):
 222            if is_false(left) or is_false(right):
 223                return exp.false()
 224            if is_null(left) or is_null(right):
 225                return exp.null()
 226            if always_true(left) and always_true(right):
 227                return exp.true()
 228            if always_true(left):
 229                return right
 230            if always_true(right):
 231                return left
 232            return _simplify_comparison(expression, left, right)
 233        elif isinstance(expression, exp.Or):
 234            if always_true(left) or always_true(right):
 235                return exp.true()
 236            if is_false(left) and is_false(right):
 237                return exp.false()
 238            if (
 239                (is_null(left) and is_null(right))
 240                or (is_null(left) and is_false(right))
 241                or (is_false(left) and is_null(right))
 242            ):
 243                return exp.null()
 244            if is_false(left):
 245                return right
 246            if is_false(right):
 247                return left
 248            return _simplify_comparison(expression, left, right, or_=True)
 249
 250    if isinstance(expression, exp.Connector):
 251        return _flat_simplify(expression, _simplify_connectors, root)
 252    return expression
 253
 254
 255LT_LTE = (exp.LT, exp.LTE)
 256GT_GTE = (exp.GT, exp.GTE)
 257
 258COMPARISONS = (
 259    *LT_LTE,
 260    *GT_GTE,
 261    exp.EQ,
 262    exp.NEQ,
 263    exp.Is,
 264)
 265
 266INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 267    exp.LT: exp.GT,
 268    exp.GT: exp.LT,
 269    exp.LTE: exp.GTE,
 270    exp.GTE: exp.LTE,
 271}
 272
 273NONDETERMINISTIC = (exp.Rand, exp.Randn)
 274
 275
 276def _simplify_comparison(expression, left, right, or_=False):
 277    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
 278        ll, lr = left.args.values()
 279        rl, rr = right.args.values()
 280
 281        largs = {ll, lr}
 282        rargs = {rl, rr}
 283
 284        matching = largs & rargs
 285        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
 286
 287        if matching and columns:
 288            try:
 289                l = first(largs - columns)
 290                r = first(rargs - columns)
 291            except StopIteration:
 292                return expression
 293
 294            if l.is_number and r.is_number:
 295                l = float(l.name)
 296                r = float(r.name)
 297            elif l.is_string and r.is_string:
 298                l = l.name
 299                r = r.name
 300            else:
 301                l = extract_date(l)
 302                if not l:
 303                    return None
 304                r = extract_date(r)
 305                if not r:
 306                    return None
 307
 308            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 309                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
 310                    return left if (av > bv if or_ else av <= bv) else right
 311                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
 312                    return left if (av < bv if or_ else av >= bv) else right
 313
 314                # we can't ever shortcut to true because the column could be null
 315                if not or_:
 316                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
 317                        if av <= bv:
 318                            return exp.false()
 319                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
 320                        if av >= bv:
 321                            return exp.false()
 322                    elif isinstance(a, exp.EQ):
 323                        if isinstance(b, exp.LT):
 324                            return exp.false() if av >= bv else a
 325                        if isinstance(b, exp.LTE):
 326                            return exp.false() if av > bv else a
 327                        if isinstance(b, exp.GT):
 328                            return exp.false() if av <= bv else a
 329                        if isinstance(b, exp.GTE):
 330                            return exp.false() if av < bv else a
 331                        if isinstance(b, exp.NEQ):
 332                            return exp.false() if av == bv else a
 333    return None
 334
 335
 336def remove_complements(expression, root=True):
 337    """
 338    Removing complements.
 339
 340    A AND NOT A -> FALSE
 341    A OR NOT A -> TRUE
 342    """
 343    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 344        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 345
 346        for a, b in itertools.permutations(expression.flatten(), 2):
 347            if is_complement(a, b):
 348                return complement
 349    return expression
 350
 351
 352def uniq_sort(expression, root=True):
 353    """
 354    Uniq and sort a connector.
 355
 356    C AND A AND B AND B -> A AND B AND C
 357    """
 358    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 359        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 360        flattened = tuple(expression.flatten())
 361        deduped = {gen(e): e for e in flattened}
 362        arr = tuple(deduped.items())
 363
 364        # check if the operands are already sorted, if not sort them
 365        # A AND C AND B -> A AND B AND C
 366        for i, (sql, e) in enumerate(arr[1:]):
 367            if sql < arr[i][0]:
 368                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 369                break
 370        else:
 371            # we didn't have to sort but maybe we need to dedup
 372            if len(deduped) < len(flattened):
 373                expression = result_func(*deduped.values(), copy=False)
 374
 375    return expression
 376
 377
 378def absorb_and_eliminate(expression, root=True):
 379    """
 380    absorption:
 381        A AND (A OR B) -> A
 382        A OR (A AND B) -> A
 383        A AND (NOT A OR B) -> A AND B
 384        A OR (NOT A AND B) -> A OR B
 385    elimination:
 386        (A AND B) OR (A AND NOT B) -> A
 387        (A OR B) AND (A OR NOT B) -> A
 388    """
 389    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 390        kind = exp.Or if isinstance(expression, exp.And) else exp.And
 391
 392        for a, b in itertools.permutations(expression.flatten(), 2):
 393            if isinstance(a, kind):
 394                aa, ab = a.unnest_operands()
 395
 396                # absorb
 397                if is_complement(b, aa):
 398                    aa.replace(exp.true() if kind == exp.And else exp.false())
 399                elif is_complement(b, ab):
 400                    ab.replace(exp.true() if kind == exp.And else exp.false())
 401                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
 402                    a.replace(exp.false() if kind == exp.And else exp.true())
 403                elif isinstance(b, kind):
 404                    # eliminate
 405                    rhs = b.unnest_operands()
 406                    ba, bb = rhs
 407
 408                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
 409                        a.replace(aa)
 410                        b.replace(aa)
 411                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
 412                        a.replace(ab)
 413                        b.replace(ab)
 414
 415    return expression
 416
 417
 418def propagate_constants(expression, root=True):
 419    """
 420    Propagate constants for conjunctions in DNF:
 421
 422    SELECT * FROM t WHERE a = b AND b = 5 becomes
 423    SELECT * FROM t WHERE a = 5 AND b = 5
 424
 425    Reference: https://www.sqlite.org/optoverview.html
 426    """
 427
 428    if (
 429        isinstance(expression, exp.And)
 430        and (root or not expression.same_parent)
 431        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 432    ):
 433        constant_mapping = {}
 434        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 435            if isinstance(expr, exp.EQ):
 436                l, r = expr.left, expr.right
 437
 438                # TODO: create a helper that can be used to detect nested literal expressions such
 439                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 440                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 441                    constant_mapping[l] = (id(l), r)
 442
 443        if constant_mapping:
 444            for column in find_all_in_scope(expression, exp.Column):
 445                parent = column.parent
 446                column_id, constant = constant_mapping.get(column) or (None, None)
 447                if (
 448                    column_id is not None
 449                    and id(column) != column_id
 450                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 451                ):
 452                    column.replace(constant.copy())
 453
 454    return expression
 455
 456
 457INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 458    exp.DateAdd: exp.Sub,
 459    exp.DateSub: exp.Add,
 460    exp.DatetimeAdd: exp.Sub,
 461    exp.DatetimeSub: exp.Add,
 462}
 463
 464INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 465    **INVERSE_DATE_OPS,
 466    exp.Add: exp.Sub,
 467    exp.Sub: exp.Add,
 468}
 469
 470
 471def _is_number(expression: exp.Expression) -> bool:
 472    return expression.is_number
 473
 474
 475def _is_interval(expression: exp.Expression) -> bool:
 476    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 477
 478
 479@catch(ModuleNotFoundError, UnsupportedUnit)
 480def simplify_equality(expression: exp.Expression) -> exp.Expression:
 481    """
 482    Use the subtraction and addition properties of equality to simplify expressions:
 483
 484        x + 1 = 3 becomes x = 2
 485
 486    There are two binary operations in the above expression: + and =
 487    Here's how we reference all the operands in the code below:
 488
 489          l     r
 490        x + 1 = 3
 491        a   b
 492    """
 493    if isinstance(expression, COMPARISONS):
 494        l, r = expression.left, expression.right
 495
 496        if l.__class__ not in INVERSE_OPS:
 497            return expression
 498
 499        if r.is_number:
 500            a_predicate = _is_number
 501            b_predicate = _is_number
 502        elif _is_date_literal(r):
 503            a_predicate = _is_date_literal
 504            b_predicate = _is_interval
 505        else:
 506            return expression
 507
 508        if l.__class__ in INVERSE_DATE_OPS:
 509            l = t.cast(exp.IntervalOp, l)
 510            a = l.this
 511            b = l.interval()
 512        else:
 513            l = t.cast(exp.Binary, l)
 514            a, b = l.left, l.right
 515
 516        if not a_predicate(a) and b_predicate(b):
 517            pass
 518        elif not a_predicate(b) and b_predicate(a):
 519            a, b = b, a
 520        else:
 521            return expression
 522
 523        return expression.__class__(
 524            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 525        )
 526    return expression
 527
 528
 529def simplify_literals(expression, root=True):
 530    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 531        return _flat_simplify(expression, _simplify_binary, root)
 532
 533    if isinstance(expression, exp.Neg):
 534        this = expression.this
 535        if this.is_number:
 536            value = this.name
 537            if value[0] == "-":
 538                return exp.Literal.number(value[1:])
 539            return exp.Literal.number(f"-{value}")
 540
 541    if type(expression) in INVERSE_DATE_OPS:
 542        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 543
 544    return expression
 545
 546
 547NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 548
 549
 550def _simplify_binary(expression, a, b):
 551    if isinstance(expression, exp.Is):
 552        if isinstance(b, exp.Not):
 553            c = b.this
 554            not_ = True
 555        else:
 556            c = b
 557            not_ = False
 558
 559        if is_null(c):
 560            if isinstance(a, exp.Literal):
 561                return exp.true() if not_ else exp.false()
 562            if is_null(a):
 563                return exp.false() if not_ else exp.true()
 564    elif isinstance(expression, NULL_OK):
 565        return None
 566    elif is_null(a) or is_null(b):
 567        return exp.null()
 568
 569    if a.is_number and b.is_number:
 570        num_a = int(a.name) if a.is_int else Decimal(a.name)
 571        num_b = int(b.name) if b.is_int else Decimal(b.name)
 572
 573        if isinstance(expression, exp.Add):
 574            return exp.Literal.number(num_a + num_b)
 575        if isinstance(expression, exp.Mul):
 576            return exp.Literal.number(num_a * num_b)
 577
 578        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 579        if isinstance(expression, exp.Sub):
 580            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 581        if isinstance(expression, exp.Div):
 582            # engines have differing int div behavior so intdiv is not safe
 583            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 584                return None
 585            return exp.Literal.number(num_a / num_b)
 586
 587        boolean = eval_boolean(expression, num_a, num_b)
 588
 589        if boolean:
 590            return boolean
 591    elif a.is_string and b.is_string:
 592        boolean = eval_boolean(expression, a.this, b.this)
 593
 594        if boolean:
 595            return boolean
 596    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 597        a, b = extract_date(a), extract_interval(b)
 598        if a and b:
 599            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 600                return date_literal(a + b)
 601            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 602                return date_literal(a - b)
 603    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 604        a, b = extract_interval(a), extract_date(b)
 605        # you cannot subtract a date from an interval
 606        if a and b and isinstance(expression, exp.Add):
 607            return date_literal(a + b)
 608    elif _is_date_literal(a) and _is_date_literal(b):
 609        if isinstance(expression, exp.Predicate):
 610            a, b = extract_date(a), extract_date(b)
 611            boolean = eval_boolean(expression, a, b)
 612            if boolean:
 613                return boolean
 614
 615    return None
 616
 617
 618def simplify_parens(expression):
 619    if not isinstance(expression, exp.Paren):
 620        return expression
 621
 622    this = expression.this
 623    parent = expression.parent
 624
 625    if not isinstance(this, exp.Select) and (
 626        not isinstance(parent, (exp.Condition, exp.Binary))
 627        or isinstance(parent, exp.Paren)
 628        or not isinstance(this, exp.Binary)
 629        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
 630        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 631        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 632        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 633    ):
 634        return this
 635    return expression
 636
 637
 638NONNULL_CONSTANTS = (
 639    exp.Literal,
 640    exp.Boolean,
 641)
 642
 643CONSTANTS = (
 644    exp.Literal,
 645    exp.Boolean,
 646    exp.Null,
 647)
 648
 649
 650def _is_nonnull_constant(expression: exp.Expression) -> bool:
 651    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
 652
 653
 654def _is_constant(expression: exp.Expression) -> bool:
 655    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
 656
 657
 658def simplify_coalesce(expression):
 659    # COALESCE(x) -> x
 660    if (
 661        isinstance(expression, exp.Coalesce)
 662        and (not expression.expressions or _is_nonnull_constant(expression.this))
 663        # COALESCE is also used as a Spark partitioning hint
 664        and not isinstance(expression.parent, exp.Hint)
 665    ):
 666        return expression.this
 667
 668    if not isinstance(expression, COMPARISONS):
 669        return expression
 670
 671    if isinstance(expression.left, exp.Coalesce):
 672        coalesce = expression.left
 673        other = expression.right
 674    elif isinstance(expression.right, exp.Coalesce):
 675        coalesce = expression.right
 676        other = expression.left
 677    else:
 678        return expression
 679
 680    # This transformation is valid for non-constants,
 681    # but it really only does anything if they are both constants.
 682    if not _is_constant(other):
 683        return expression
 684
 685    # Find the first constant arg
 686    for arg_index, arg in enumerate(coalesce.expressions):
 687        if _is_constant(arg):
 688            break
 689    else:
 690        return expression
 691
 692    coalesce.set("expressions", coalesce.expressions[:arg_index])
 693
 694    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 695    # since we already remove COALESCE at the top of this function.
 696    coalesce = coalesce if coalesce.expressions else coalesce.this
 697
 698    # This expression is more complex than when we started, but it will get simplified further
 699    return exp.paren(
 700        exp.or_(
 701            exp.and_(
 702                coalesce.is_(exp.null()).not_(copy=False),
 703                expression.copy(),
 704                copy=False,
 705            ),
 706            exp.and_(
 707                coalesce.is_(exp.null()),
 708                type(expression)(this=arg.copy(), expression=other.copy()),
 709                copy=False,
 710            ),
 711            copy=False,
 712        )
 713    )
 714
 715
 716CONCATS = (exp.Concat, exp.DPipe)
 717
 718
 719def simplify_concat(expression):
 720    """Reduces all groups that contain string literals by concatenating them."""
 721    if not isinstance(expression, CONCATS) or (
 722        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 723        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 724    ):
 725        return expression
 726
 727    if isinstance(expression, exp.ConcatWs):
 728        sep_expr, *expressions = expression.expressions
 729        sep = sep_expr.name
 730        concat_type = exp.ConcatWs
 731        args = {}
 732    else:
 733        expressions = expression.expressions
 734        sep = ""
 735        concat_type = exp.Concat
 736        args = {
 737            "safe": expression.args.get("safe"),
 738            "coalesce": expression.args.get("coalesce"),
 739        }
 740
 741    new_args = []
 742    for is_string_group, group in itertools.groupby(
 743        expressions or expression.flatten(), lambda e: e.is_string
 744    ):
 745        if is_string_group:
 746            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 747        else:
 748            new_args.extend(group)
 749
 750    if len(new_args) == 1 and new_args[0].is_string:
 751        return new_args[0]
 752
 753    if concat_type is exp.ConcatWs:
 754        new_args = [sep_expr] + new_args
 755
 756    return concat_type(expressions=new_args, **args)
 757
 758
 759def simplify_conditionals(expression):
 760    """Simplifies expressions like IF, CASE if their condition is statically known."""
 761    if isinstance(expression, exp.Case):
 762        this = expression.this
 763        for case in expression.args["ifs"]:
 764            cond = case.this
 765            if this:
 766                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 767                cond = cond.replace(this.pop().eq(cond))
 768
 769            if always_true(cond):
 770                return case.args["true"]
 771
 772            if always_false(cond):
 773                case.pop()
 774                if not expression.args["ifs"]:
 775                    return expression.args.get("default") or exp.null()
 776    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 777        if always_true(expression.this):
 778            return expression.args["true"]
 779        if always_false(expression.this):
 780            return expression.args.get("false") or exp.null()
 781
 782    return expression
 783
 784
 785def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 786    """
 787    Reduces a prefix check to either TRUE or FALSE if both the string and the
 788    prefix are statically known.
 789
 790    Example:
 791        >>> from sqlglot import parse_one
 792        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 793        'TRUE'
 794    """
 795    if (
 796        isinstance(expression, exp.StartsWith)
 797        and expression.this.is_string
 798        and expression.expression.is_string
 799    ):
 800        return exp.convert(expression.name.startswith(expression.expression.name))
 801
 802    return expression
 803
 804
 805DateRange = t.Tuple[datetime.date, datetime.date]
 806
 807
 808def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 809    """
 810    Get the date range for a DATE_TRUNC equality comparison:
 811
 812    Example:
 813        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 814    Returns:
 815        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 816    """
 817    floor = date_floor(date, unit, dialect)
 818
 819    if date != floor:
 820        # This will always be False, except for NULL values.
 821        return None
 822
 823    return floor, floor + interval(unit)
 824
 825
 826def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 827    """Get the logical expression for a date range"""
 828    return exp.and_(
 829        left >= date_literal(drange[0]),
 830        left < date_literal(drange[1]),
 831        copy=False,
 832    )
 833
 834
 835def _datetrunc_eq(
 836    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 837) -> t.Optional[exp.Expression]:
 838    drange = _datetrunc_range(date, unit, dialect)
 839    if not drange:
 840        return None
 841
 842    return _datetrunc_eq_expression(left, drange)
 843
 844
 845def _datetrunc_neq(
 846    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 847) -> t.Optional[exp.Expression]:
 848    drange = _datetrunc_range(date, unit, dialect)
 849    if not drange:
 850        return None
 851
 852    return exp.and_(
 853        left < date_literal(drange[0]),
 854        left >= date_literal(drange[1]),
 855        copy=False,
 856    )
 857
 858
 859DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 860    exp.LT: lambda l, dt, u, d: l
 861    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
 862    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
 863    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
 864    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
 865    exp.EQ: _datetrunc_eq,
 866    exp.NEQ: _datetrunc_neq,
 867}
 868DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 869DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 870
 871
 872def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 873    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 874
 875
 876@catch(ModuleNotFoundError, UnsupportedUnit)
 877def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
 878    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 879    comparison = expression.__class__
 880
 881    if isinstance(expression, DATETRUNCS):
 882        date = extract_date(expression.this)
 883        if date and expression.unit:
 884            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
 885    elif comparison not in DATETRUNC_COMPARISONS:
 886        return expression
 887
 888    if isinstance(expression, exp.Binary):
 889        l, r = expression.left, expression.right
 890
 891        if not _is_datetrunc_predicate(l, r):
 892            return expression
 893
 894        l = t.cast(exp.DateTrunc, l)
 895        unit = l.unit.name.lower()
 896        date = extract_date(r)
 897
 898        if not date:
 899            return expression
 900
 901        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
 902    elif isinstance(expression, exp.In):
 903        l = expression.this
 904        rs = expression.expressions
 905
 906        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 907            l = t.cast(exp.DateTrunc, l)
 908            unit = l.unit.name.lower()
 909
 910            ranges = []
 911            for r in rs:
 912                date = extract_date(r)
 913                if not date:
 914                    return expression
 915                drange = _datetrunc_range(date, unit, dialect)
 916                if drange:
 917                    ranges.append(drange)
 918
 919            if not ranges:
 920                return expression
 921
 922            ranges = merge_ranges(ranges)
 923
 924            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
 925
 926    return expression
 927
 928
 929def sort_comparison(expression: exp.Expression) -> exp.Expression:
 930    if expression.__class__ in COMPLEMENT_COMPARISONS:
 931        l, r = expression.this, expression.expression
 932        l_column = isinstance(l, exp.Column)
 933        r_column = isinstance(r, exp.Column)
 934        l_const = _is_constant(l)
 935        r_const = _is_constant(r)
 936
 937        if (l_column and not r_column) or (r_const and not l_const):
 938            return expression
 939        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
 940            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
 941                this=r, expression=l
 942            )
 943    return expression
 944
 945
 946# CROSS joins result in an empty table if the right table is empty.
 947# So we can only simplify certain types of joins to CROSS.
 948# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 949JOINS = {
 950    ("", ""),
 951    ("", "INNER"),
 952    ("RIGHT", ""),
 953    ("RIGHT", "OUTER"),
 954}
 955
 956
 957def remove_where_true(expression):
 958    for where in expression.find_all(exp.Where):
 959        if always_true(where.this):
 960            where.parent.set("where", None)
 961    for join in expression.find_all(exp.Join):
 962        if (
 963            always_true(join.args.get("on"))
 964            and not join.args.get("using")
 965            and not join.args.get("method")
 966            and (join.side, join.kind) in JOINS
 967        ):
 968            join.set("on", None)
 969            join.set("side", None)
 970            join.set("kind", "CROSS")
 971
 972
 973def always_true(expression):
 974    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 975        expression, exp.Literal
 976    )
 977
 978
 979def always_false(expression):
 980    return is_false(expression) or is_null(expression)
 981
 982
 983def is_complement(a, b):
 984    return isinstance(b, exp.Not) and b.this == a
 985
 986
 987def is_false(a: exp.Expression) -> bool:
 988    return type(a) is exp.Boolean and not a.this
 989
 990
 991def is_null(a: exp.Expression) -> bool:
 992    return type(a) is exp.Null
 993
 994
 995def eval_boolean(expression, a, b):
 996    if isinstance(expression, (exp.EQ, exp.Is)):
 997        return boolean_literal(a == b)
 998    if isinstance(expression, exp.NEQ):
 999        return boolean_literal(a != b)
1000    if isinstance(expression, exp.GT):
1001        return boolean_literal(a > b)
1002    if isinstance(expression, exp.GTE):
1003        return boolean_literal(a >= b)
1004    if isinstance(expression, exp.LT):
1005        return boolean_literal(a < b)
1006    if isinstance(expression, exp.LTE):
1007        return boolean_literal(a <= b)
1008    return None
1009
1010
1011def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
1012    if isinstance(value, datetime.datetime):
1013        return value.date()
1014    if isinstance(value, datetime.date):
1015        return value
1016    try:
1017        return datetime.datetime.fromisoformat(value).date()
1018    except ValueError:
1019        return None
1020
1021
1022def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
1023    if isinstance(value, datetime.datetime):
1024        return value
1025    if isinstance(value, datetime.date):
1026        return datetime.datetime(year=value.year, month=value.month, day=value.day)
1027    try:
1028        return datetime.datetime.fromisoformat(value)
1029    except ValueError:
1030        return None
1031
1032
1033def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
1034    if not value:
1035        return None
1036    if to.is_type(exp.DataType.Type.DATE):
1037        return cast_as_date(value)
1038    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
1039        return cast_as_datetime(value)
1040    return None
1041
1042
1043def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
1044    if isinstance(cast, exp.Cast):
1045        to = cast.to
1046    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
1047        to = exp.DataType.build(exp.DataType.Type.DATE)
1048    else:
1049        return None
1050
1051    if isinstance(cast.this, exp.Literal):
1052        value: t.Any = cast.this.name
1053    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
1054        value = extract_date(cast.this)
1055    else:
1056        return None
1057    return cast_value(value, to)
1058
1059
1060def _is_date_literal(expression: exp.Expression) -> bool:
1061    return extract_date(expression) is not None
1062
1063
1064def extract_interval(expression):
1065    try:
1066        n = int(expression.name)
1067        unit = expression.text("unit").lower()
1068        return interval(unit, n)
1069    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
1070        return None
1071
1072
1073def date_literal(date):
1074    return exp.cast(
1075        exp.Literal.string(date),
1076        (
1077            exp.DataType.Type.DATETIME
1078            if isinstance(date, datetime.datetime)
1079            else exp.DataType.Type.DATE
1080        ),
1081    )
1082
1083
1084def interval(unit: str, n: int = 1):
1085    from dateutil.relativedelta import relativedelta
1086
1087    if unit == "year":
1088        return relativedelta(years=1 * n)
1089    if unit == "quarter":
1090        return relativedelta(months=3 * n)
1091    if unit == "month":
1092        return relativedelta(months=1 * n)
1093    if unit == "week":
1094        return relativedelta(weeks=1 * n)
1095    if unit == "day":
1096        return relativedelta(days=1 * n)
1097    if unit == "hour":
1098        return relativedelta(hours=1 * n)
1099    if unit == "minute":
1100        return relativedelta(minutes=1 * n)
1101    if unit == "second":
1102        return relativedelta(seconds=1 * n)
1103
1104    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1105
1106
1107def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
1108    if unit == "year":
1109        return d.replace(month=1, day=1)
1110    if unit == "quarter":
1111        if d.month <= 3:
1112            return d.replace(month=1, day=1)
1113        elif d.month <= 6:
1114            return d.replace(month=4, day=1)
1115        elif d.month <= 9:
1116            return d.replace(month=7, day=1)
1117        else:
1118            return d.replace(month=10, day=1)
1119    if unit == "month":
1120        return d.replace(month=d.month, day=1)
1121    if unit == "week":
1122        # Assuming week starts on Monday (0) and ends on Sunday (6)
1123        return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
1124    if unit == "day":
1125        return d
1126
1127    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1128
1129
1130def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
1131    floor = date_floor(d, unit, dialect)
1132
1133    if floor == d:
1134        return d
1135
1136    return floor + interval(unit)
1137
1138
1139def boolean_literal(condition):
1140    return exp.true() if condition else exp.false()
1141
1142
1143def _flat_simplify(expression, simplifier, root=True):
1144    if root or not expression.same_parent:
1145        operands = []
1146        queue = deque(expression.flatten(unnest=False))
1147        size = len(queue)
1148
1149        while queue:
1150            a = queue.popleft()
1151
1152            for b in queue:
1153                result = simplifier(expression, a, b)
1154
1155                if result and result is not expression:
1156                    queue.remove(b)
1157                    queue.appendleft(result)
1158                    break
1159            else:
1160                operands.append(a)
1161
1162        if len(operands) < size:
1163            return functools.reduce(
1164                lambda a, b: expression.__class__(this=a, expression=b), operands
1165            )
1166    return expression
1167
1168
1169def gen(expression: t.Any) -> str:
1170    """Simple pseudo sql generator for quickly generating sortable and uniq strings.
1171
1172    Sorting and deduping sql is a necessary step for optimization. Calling the actual
1173    generator is expensive so we have a bare minimum sql generator here.
1174    """
1175    return Gen().gen(expression)
1176
1177
1178class Gen:
1179    def __init__(self):
1180        self.stack = []
1181        self.sqls = []
1182
1183    def gen(self, expression: exp.Expression) -> str:
1184        self.stack = [expression]
1185        self.sqls.clear()
1186
1187        while self.stack:
1188            node = self.stack.pop()
1189
1190            if isinstance(node, exp.Expression):
1191                exp_handler_name = f"{node.key}_sql"
1192
1193                if hasattr(self, exp_handler_name):
1194                    getattr(self, exp_handler_name)(node)
1195                elif isinstance(node, exp.Func):
1196                    self._function(node)
1197                else:
1198                    key = node.key.upper()
1199                    self.stack.append(f"{key} " if self._args(node) else key)
1200            elif type(node) is list:
1201                for n in reversed(node):
1202                    if n is not None:
1203                        self.stack.extend((n, ","))
1204                if node:
1205                    self.stack.pop()
1206            else:
1207                if node is not None:
1208                    self.sqls.append(str(node))
1209
1210        return "".join(self.sqls)
1211
1212    def add_sql(self, e: exp.Add) -> None:
1213        self._binary(e, " + ")
1214
1215    def alias_sql(self, e: exp.Alias) -> None:
1216        self.stack.extend(
1217            (
1218                e.args.get("alias"),
1219                " AS ",
1220                e.args.get("this"),
1221            )
1222        )
1223
1224    def and_sql(self, e: exp.And) -> None:
1225        self._binary(e, " AND ")
1226
1227    def anonymous_sql(self, e: exp.Anonymous) -> None:
1228        this = e.this
1229        if isinstance(this, str):
1230            name = this.upper()
1231        elif isinstance(this, exp.Identifier):
1232            name = this.this
1233            name = f'"{name}"' if this.quoted else name.upper()
1234        else:
1235            raise ValueError(
1236                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
1237            )
1238
1239        self.stack.extend(
1240            (
1241                ")",
1242                e.expressions,
1243                "(",
1244                name,
1245            )
1246        )
1247
1248    def between_sql(self, e: exp.Between) -> None:
1249        self.stack.extend(
1250            (
1251                e.args.get("high"),
1252                " AND ",
1253                e.args.get("low"),
1254                " BETWEEN ",
1255                e.this,
1256            )
1257        )
1258
1259    def boolean_sql(self, e: exp.Boolean) -> None:
1260        self.stack.append("TRUE" if e.this else "FALSE")
1261
1262    def bracket_sql(self, e: exp.Bracket) -> None:
1263        self.stack.extend(
1264            (
1265                "]",
1266                e.expressions,
1267                "[",
1268                e.this,
1269            )
1270        )
1271
1272    def column_sql(self, e: exp.Column) -> None:
1273        for p in reversed(e.parts):
1274            self.stack.extend((p, "."))
1275        self.stack.pop()
1276
1277    def datatype_sql(self, e: exp.DataType) -> None:
1278        self._args(e, 1)
1279        self.stack.append(f"{e.this.name} ")
1280
1281    def div_sql(self, e: exp.Div) -> None:
1282        self._binary(e, " / ")
1283
1284    def dot_sql(self, e: exp.Dot) -> None:
1285        self._binary(e, ".")
1286
1287    def eq_sql(self, e: exp.EQ) -> None:
1288        self._binary(e, " = ")
1289
1290    def from_sql(self, e: exp.From) -> None:
1291        self.stack.extend((e.this, "FROM "))
1292
1293    def gt_sql(self, e: exp.GT) -> None:
1294        self._binary(e, " > ")
1295
1296    def gte_sql(self, e: exp.GTE) -> None:
1297        self._binary(e, " >= ")
1298
1299    def identifier_sql(self, e: exp.Identifier) -> None:
1300        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
1301
1302    def ilike_sql(self, e: exp.ILike) -> None:
1303        self._binary(e, " ILIKE ")
1304
1305    def in_sql(self, e: exp.In) -> None:
1306        self.stack.append(")")
1307        self._args(e, 1)
1308        self.stack.extend(
1309            (
1310                "(",
1311                " IN ",
1312                e.this,
1313            )
1314        )
1315
1316    def intdiv_sql(self, e: exp.IntDiv) -> None:
1317        self._binary(e, " DIV ")
1318
1319    def is_sql(self, e: exp.Is) -> None:
1320        self._binary(e, " IS ")
1321
1322    def like_sql(self, e: exp.Like) -> None:
1323        self._binary(e, " Like ")
1324
1325    def literal_sql(self, e: exp.Literal) -> None:
1326        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
1327
1328    def lt_sql(self, e: exp.LT) -> None:
1329        self._binary(e, " < ")
1330
1331    def lte_sql(self, e: exp.LTE) -> None:
1332        self._binary(e, " <= ")
1333
1334    def mod_sql(self, e: exp.Mod) -> None:
1335        self._binary(e, " % ")
1336
1337    def mul_sql(self, e: exp.Mul) -> None:
1338        self._binary(e, " * ")
1339
1340    def neg_sql(self, e: exp.Neg) -> None:
1341        self._unary(e, "-")
1342
1343    def neq_sql(self, e: exp.NEQ) -> None:
1344        self._binary(e, " <> ")
1345
1346    def not_sql(self, e: exp.Not) -> None:
1347        self._unary(e, "NOT ")
1348
1349    def null_sql(self, e: exp.Null) -> None:
1350        self.stack.append("NULL")
1351
1352    def or_sql(self, e: exp.Or) -> None:
1353        self._binary(e, " OR ")
1354
1355    def paren_sql(self, e: exp.Paren) -> None:
1356        self.stack.extend(
1357            (
1358                ")",
1359                e.this,
1360                "(",
1361            )
1362        )
1363
1364    def sub_sql(self, e: exp.Sub) -> None:
1365        self._binary(e, " - ")
1366
1367    def subquery_sql(self, e: exp.Subquery) -> None:
1368        self._args(e, 2)
1369        alias = e.args.get("alias")
1370        if alias:
1371            self.stack.append(alias)
1372        self.stack.extend((")", e.this, "("))
1373
1374    def table_sql(self, e: exp.Table) -> None:
1375        self._args(e, 4)
1376        alias = e.args.get("alias")
1377        if alias:
1378            self.stack.append(alias)
1379        for p in reversed(e.parts):
1380            self.stack.extend((p, "."))
1381        self.stack.pop()
1382
1383    def tablealias_sql(self, e: exp.TableAlias) -> None:
1384        columns = e.columns
1385
1386        if columns:
1387            self.stack.extend((")", columns, "("))
1388
1389        self.stack.extend((e.this, " AS "))
1390
1391    def var_sql(self, e: exp.Var) -> None:
1392        self.stack.append(e.this)
1393
1394    def _binary(self, e: exp.Binary, op: str) -> None:
1395        self.stack.extend((e.expression, op, e.this))
1396
1397    def _unary(self, e: exp.Unary, op: str) -> None:
1398        self.stack.extend((e.this, op))
1399
1400    def _function(self, e: exp.Func) -> None:
1401        self.stack.extend(
1402            (
1403                ")",
1404                list(e.args.values()),
1405                "(",
1406                e.sql_name(),
1407            )
1408        )
1409
1410    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
1411        kvs = []
1412        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types
1413
1414        for k in arg_types or arg_types:
1415            v = node.args.get(k)
1416
1417            if v is not None:
1418                kvs.append([f":{k}", v])
1419        if kvs:
1420            self.stack.append(kvs)
1421            return True
1422        return False
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
27class UnsupportedUnit(Exception):
28    pass

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None):
 31def simplify(
 32    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
 33):
 34    """
 35    Rewrite sqlglot AST to simplify expressions.
 36
 37    Example:
 38        >>> import sqlglot
 39        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 40        >>> simplify(expression).sql()
 41        'TRUE'
 42
 43    Args:
 44        expression (sqlglot.Expression): expression to simplify
 45        constant_propagation: whether the constant propagation rule should be used
 46
 47    Returns:
 48        sqlglot.Expression: simplified expression
 49    """
 50
 51    dialect = Dialect.get_or_raise(dialect)
 52
 53    def _simplify(expression, root=True):
 54        if expression.meta.get(FINAL):
 55            return expression
 56
 57        # group by expressions cannot be simplified, for example
 58        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 59        # the projection must exactly match the group by key
 60        group = expression.args.get("group")
 61
 62        if group and hasattr(expression, "selects"):
 63            groups = set(group.expressions)
 64            group.meta[FINAL] = True
 65
 66            for e in expression.selects:
 67                for node in e.walk():
 68                    if node in groups:
 69                        e.meta[FINAL] = True
 70                        break
 71
 72            having = expression.args.get("having")
 73            if having:
 74                for node in having.walk():
 75                    if node in groups:
 76                        having.meta[FINAL] = True
 77                        break
 78
 79        # Pre-order transformations
 80        node = expression
 81        node = rewrite_between(node)
 82        node = uniq_sort(node, root)
 83        node = absorb_and_eliminate(node, root)
 84        node = simplify_concat(node)
 85        node = simplify_conditionals(node)
 86
 87        if constant_propagation:
 88            node = propagate_constants(node, root)
 89
 90        exp.replace_children(node, lambda e: _simplify(e, False))
 91
 92        # Post-order transformations
 93        node = simplify_not(node)
 94        node = flatten(node)
 95        node = simplify_connectors(node, root)
 96        node = remove_complements(node, root)
 97        node = simplify_coalesce(node)
 98        node.parent = expression.parent
 99        node = simplify_literals(node, root)
100        node = simplify_equality(node)
101        node = simplify_parens(node)
102        node = simplify_datetrunc(node, dialect)
103        node = sort_comparison(node)
104        node = simplify_startswith(node)
105
106        if root:
107            expression.replace(node)
108        return node
109
110    expression = while_changing(expression, _simplify)
111    remove_where_true(expression)
112    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
115def catch(*exceptions):
116    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
117
118    def decorator(func):
119        def wrapped(expression, *args, **kwargs):
120            try:
121                return func(expression, *args, **kwargs)
122            except exceptions:
123                return expression
124
125        return wrapped
126
127    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
130def rewrite_between(expression: exp.Expression) -> exp.Expression:
131    """Rewrite x between y and z to x >= y AND x <= z.
132
133    This is done because comparison simplification is only done on lt/lte/gt/gte.
134    """
135    if isinstance(expression, exp.Between):
136        negate = isinstance(expression.parent, exp.Not)
137
138        expression = exp.and_(
139            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
140            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
141            copy=False,
142        )
143
144        if negate:
145            expression = exp.paren(expression, copy=False)
146
147    return expression

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
160def simplify_not(expression):
161    """
162    Demorgan's Law
163    NOT (x OR y) -> NOT x AND NOT y
164    NOT (x AND y) -> NOT x OR NOT y
165    """
166    if isinstance(expression, exp.Not):
167        this = expression.this
168        if is_null(this):
169            return exp.null()
170        if this.__class__ in COMPLEMENT_COMPARISONS:
171            return COMPLEMENT_COMPARISONS[this.__class__](
172                this=this.this, expression=this.expression
173            )
174        if isinstance(this, exp.Paren):
175            condition = this.unnest()
176            if isinstance(condition, exp.And):
177                return exp.paren(
178                    exp.or_(
179                        exp.not_(condition.left, copy=False),
180                        exp.not_(condition.right, copy=False),
181                        copy=False,
182                    )
183                )
184            if isinstance(condition, exp.Or):
185                return exp.paren(
186                    exp.and_(
187                        exp.not_(condition.left, copy=False),
188                        exp.not_(condition.right, copy=False),
189                        copy=False,
190                    )
191                )
192            if is_null(condition):
193                return exp.null()
194        if always_true(this):
195            return exp.false()
196        if is_false(this):
197            return exp.true()
198        if isinstance(this, exp.Not):
199            # double negation
200            # NOT NOT x -> x
201            return this.this
202    return expression

Demorgan's Law NOT (x OR y) -> NOT x AND NOT y NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
205def flatten(expression):
206    """
207    A AND (B AND C) -> A AND B AND C
208    A OR (B OR C) -> A OR B OR C
209    """
210    if isinstance(expression, exp.Connector):
211        for node in expression.args.values():
212            child = node.unnest()
213            if isinstance(child, expression.__class__):
214                node.replace(child)
215    return expression

A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
218def simplify_connectors(expression, root=True):
219    def _simplify_connectors(expression, left, right):
220        if left == right:
221            return left
222        if isinstance(expression, exp.And):
223            if is_false(left) or is_false(right):
224                return exp.false()
225            if is_null(left) or is_null(right):
226                return exp.null()
227            if always_true(left) and always_true(right):
228                return exp.true()
229            if always_true(left):
230                return right
231            if always_true(right):
232                return left
233            return _simplify_comparison(expression, left, right)
234        elif isinstance(expression, exp.Or):
235            if always_true(left) or always_true(right):
236                return exp.true()
237            if is_false(left) and is_false(right):
238                return exp.false()
239            if (
240                (is_null(left) and is_null(right))
241                or (is_null(left) and is_false(right))
242                or (is_false(left) and is_null(right))
243            ):
244                return exp.null()
245            if is_false(left):
246                return right
247            if is_false(right):
248                return left
249            return _simplify_comparison(expression, left, right, or_=True)
250
251    if isinstance(expression, exp.Connector):
252        return _flat_simplify(expression, _simplify_connectors, root)
253    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
def remove_complements(expression, root=True):
337def remove_complements(expression, root=True):
338    """
339    Removing complements.
340
341    A AND NOT A -> FALSE
342    A OR NOT A -> TRUE
343    """
344    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
345        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
346
347        for a, b in itertools.permutations(expression.flatten(), 2):
348            if is_complement(a, b):
349                return complement
350    return expression

Removing complements.

A AND NOT A -> FALSE A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
353def uniq_sort(expression, root=True):
354    """
355    Uniq and sort a connector.
356
357    C AND A AND B AND B -> A AND B AND C
358    """
359    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
360        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
361        flattened = tuple(expression.flatten())
362        deduped = {gen(e): e for e in flattened}
363        arr = tuple(deduped.items())
364
365        # check if the operands are already sorted, if not sort them
366        # A AND C AND B -> A AND B AND C
367        for i, (sql, e) in enumerate(arr[1:]):
368            if sql < arr[i][0]:
369                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
370                break
371        else:
372            # we didn't have to sort but maybe we need to dedup
373            if len(deduped) < len(flattened):
374                expression = result_func(*deduped.values(), copy=False)
375
376    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
379def absorb_and_eliminate(expression, root=True):
380    """
381    absorption:
382        A AND (A OR B) -> A
383        A OR (A AND B) -> A
384        A AND (NOT A OR B) -> A AND B
385        A OR (NOT A AND B) -> A OR B
386    elimination:
387        (A AND B) OR (A AND NOT B) -> A
388        (A OR B) AND (A OR NOT B) -> A
389    """
390    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
391        kind = exp.Or if isinstance(expression, exp.And) else exp.And
392
393        for a, b in itertools.permutations(expression.flatten(), 2):
394            if isinstance(a, kind):
395                aa, ab = a.unnest_operands()
396
397                # absorb
398                if is_complement(b, aa):
399                    aa.replace(exp.true() if kind == exp.And else exp.false())
400                elif is_complement(b, ab):
401                    ab.replace(exp.true() if kind == exp.And else exp.false())
402                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
403                    a.replace(exp.false() if kind == exp.And else exp.true())
404                elif isinstance(b, kind):
405                    # eliminate
406                    rhs = b.unnest_operands()
407                    ba, bb = rhs
408
409                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
410                        a.replace(aa)
411                        b.replace(aa)
412                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
413                        a.replace(ab)
414                        b.replace(ab)
415
416    return expression

absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
419def propagate_constants(expression, root=True):
420    """
421    Propagate constants for conjunctions in DNF:
422
423    SELECT * FROM t WHERE a = b AND b = 5 becomes
424    SELECT * FROM t WHERE a = 5 AND b = 5
425
426    Reference: https://www.sqlite.org/optoverview.html
427    """
428
429    if (
430        isinstance(expression, exp.And)
431        and (root or not expression.same_parent)
432        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
433    ):
434        constant_mapping = {}
435        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
436            if isinstance(expr, exp.EQ):
437                l, r = expr.left, expr.right
438
439                # TODO: create a helper that can be used to detect nested literal expressions such
440                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
441                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
442                    constant_mapping[l] = (id(l), r)
443
444        if constant_mapping:
445            for column in find_all_in_scope(expression, exp.Column):
446                parent = column.parent
447                column_id, constant = constant_mapping.get(column) or (None, None)
448                if (
449                    column_id is not None
450                    and id(column) != column_id
451                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
452                ):
453                    column.replace(constant.copy())
454
455    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
119        def wrapped(expression, *args, **kwargs):
120            try:
121                return func(expression, *args, **kwargs)
122            except exceptions:
123                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
530def simplify_literals(expression, root=True):
531    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
532        return _flat_simplify(expression, _simplify_binary, root)
533
534    if isinstance(expression, exp.Neg):
535        this = expression.this
536        if this.is_number:
537            value = this.name
538            if value[0] == "-":
539                return exp.Literal.number(value[1:])
540            return exp.Literal.number(f"-{value}")
541
542    if type(expression) in INVERSE_DATE_OPS:
543        return _simplify_binary(expression, expression.this, expression.interval()) or expression
544
545    return expression
def simplify_parens(expression):
619def simplify_parens(expression):
620    if not isinstance(expression, exp.Paren):
621        return expression
622
623    this = expression.this
624    parent = expression.parent
625
626    if not isinstance(this, exp.Select) and (
627        not isinstance(parent, (exp.Condition, exp.Binary))
628        or isinstance(parent, exp.Paren)
629        or not isinstance(this, exp.Binary)
630        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
631        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
632        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
633        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
634    ):
635        return this
636    return expression
NONNULL_CONSTANTS = (<class 'sqlglot.expressions.Literal'>, <class 'sqlglot.expressions.Boolean'>)
def simplify_coalesce(expression):
659def simplify_coalesce(expression):
660    # COALESCE(x) -> x
661    if (
662        isinstance(expression, exp.Coalesce)
663        and (not expression.expressions or _is_nonnull_constant(expression.this))
664        # COALESCE is also used as a Spark partitioning hint
665        and not isinstance(expression.parent, exp.Hint)
666    ):
667        return expression.this
668
669    if not isinstance(expression, COMPARISONS):
670        return expression
671
672    if isinstance(expression.left, exp.Coalesce):
673        coalesce = expression.left
674        other = expression.right
675    elif isinstance(expression.right, exp.Coalesce):
676        coalesce = expression.right
677        other = expression.left
678    else:
679        return expression
680
681    # This transformation is valid for non-constants,
682    # but it really only does anything if they are both constants.
683    if not _is_constant(other):
684        return expression
685
686    # Find the first constant arg
687    for arg_index, arg in enumerate(coalesce.expressions):
688        if _is_constant(arg):
689            break
690    else:
691        return expression
692
693    coalesce.set("expressions", coalesce.expressions[:arg_index])
694
695    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
696    # since we already remove COALESCE at the top of this function.
697    coalesce = coalesce if coalesce.expressions else coalesce.this
698
699    # This expression is more complex than when we started, but it will get simplified further
700    return exp.paren(
701        exp.or_(
702            exp.and_(
703                coalesce.is_(exp.null()).not_(copy=False),
704                expression.copy(),
705                copy=False,
706            ),
707            exp.and_(
708                coalesce.is_(exp.null()),
709                type(expression)(this=arg.copy(), expression=other.copy()),
710                copy=False,
711            ),
712            copy=False,
713        )
714    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
720def simplify_concat(expression):
721    """Reduces all groups that contain string literals by concatenating them."""
722    if not isinstance(expression, CONCATS) or (
723        # We can't reduce a CONCAT_WS call if we don't statically know the separator
724        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
725    ):
726        return expression
727
728    if isinstance(expression, exp.ConcatWs):
729        sep_expr, *expressions = expression.expressions
730        sep = sep_expr.name
731        concat_type = exp.ConcatWs
732        args = {}
733    else:
734        expressions = expression.expressions
735        sep = ""
736        concat_type = exp.Concat
737        args = {
738            "safe": expression.args.get("safe"),
739            "coalesce": expression.args.get("coalesce"),
740        }
741
742    new_args = []
743    for is_string_group, group in itertools.groupby(
744        expressions or expression.flatten(), lambda e: e.is_string
745    ):
746        if is_string_group:
747            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
748        else:
749            new_args.extend(group)
750
751    if len(new_args) == 1 and new_args[0].is_string:
752        return new_args[0]
753
754    if concat_type is exp.ConcatWs:
755        new_args = [sep_expr] + new_args
756
757    return concat_type(expressions=new_args, **args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
760def simplify_conditionals(expression):
761    """Simplifies expressions like IF, CASE if their condition is statically known."""
762    if isinstance(expression, exp.Case):
763        this = expression.this
764        for case in expression.args["ifs"]:
765            cond = case.this
766            if this:
767                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
768                cond = cond.replace(this.pop().eq(cond))
769
770            if always_true(cond):
771                return case.args["true"]
772
773            if always_false(cond):
774                case.pop()
775                if not expression.args["ifs"]:
776                    return expression.args.get("default") or exp.null()
777    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
778        if always_true(expression.this):
779            return expression.args["true"]
780        if always_false(expression.this):
781            return expression.args.get("false") or exp.null()
782
783    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
786def simplify_startswith(expression: exp.Expression) -> exp.Expression:
787    """
788    Reduces a prefix check to either TRUE or FALSE if both the string and the
789    prefix are statically known.
790
791    Example:
792        >>> from sqlglot import parse_one
793        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
794        'TRUE'
795    """
796    if (
797        isinstance(expression, exp.StartsWith)
798        and expression.this.is_string
799        and expression.expression.is_string
800    ):
801        return exp.convert(expression.name.startswith(expression.expression.name))
802
803    return expression

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.GTE'>}
def simplify_datetrunc(expression, *args, **kwargs):
119        def wrapped(expression, *args, **kwargs):
120            try:
121                return func(expression, *args, **kwargs)
122            except exceptions:
123                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
930def sort_comparison(expression: exp.Expression) -> exp.Expression:
931    if expression.__class__ in COMPLEMENT_COMPARISONS:
932        l, r = expression.this, expression.expression
933        l_column = isinstance(l, exp.Column)
934        r_column = isinstance(r, exp.Column)
935        l_const = _is_constant(l)
936        r_const = _is_constant(r)
937
938        if (l_column and not r_column) or (r_const and not l_const):
939            return expression
940        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
941            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
942                this=r, expression=l
943            )
944    return expression
JOINS = {('', ''), ('', 'INNER'), ('RIGHT', 'OUTER'), ('RIGHT', '')}
def remove_where_true(expression):
958def remove_where_true(expression):
959    for where in expression.find_all(exp.Where):
960        if always_true(where.this):
961            where.parent.set("where", None)
962    for join in expression.find_all(exp.Join):
963        if (
964            always_true(join.args.get("on"))
965            and not join.args.get("using")
966            and not join.args.get("method")
967            and (join.side, join.kind) in JOINS
968        ):
969            join.set("on", None)
970            join.set("side", None)
971            join.set("kind", "CROSS")
def always_true(expression):
974def always_true(expression):
975    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
976        expression, exp.Literal
977    )
def always_false(expression):
980def always_false(expression):
981    return is_false(expression) or is_null(expression)
def is_complement(a, b):
984def is_complement(a, b):
985    return isinstance(b, exp.Not) and b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
988def is_false(a: exp.Expression) -> bool:
989    return type(a) is exp.Boolean and not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
992def is_null(a: exp.Expression) -> bool:
993    return type(a) is exp.Null
def eval_boolean(expression, a, b):
 996def eval_boolean(expression, a, b):
 997    if isinstance(expression, (exp.EQ, exp.Is)):
 998        return boolean_literal(a == b)
 999    if isinstance(expression, exp.NEQ):
1000        return boolean_literal(a != b)
1001    if isinstance(expression, exp.GT):
1002        return boolean_literal(a > b)
1003    if isinstance(expression, exp.GTE):
1004        return boolean_literal(a >= b)
1005    if isinstance(expression, exp.LT):
1006        return boolean_literal(a < b)
1007    if isinstance(expression, exp.LTE):
1008        return boolean_literal(a <= b)
1009    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
1012def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
1013    if isinstance(value, datetime.datetime):
1014        return value.date()
1015    if isinstance(value, datetime.date):
1016        return value
1017    try:
1018        return datetime.datetime.fromisoformat(value).date()
1019    except ValueError:
1020        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
1023def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
1024    if isinstance(value, datetime.datetime):
1025        return value
1026    if isinstance(value, datetime.date):
1027        return datetime.datetime(year=value.year, month=value.month, day=value.day)
1028    try:
1029        return datetime.datetime.fromisoformat(value)
1030    except ValueError:
1031        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
1034def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
1035    if not value:
1036        return None
1037    if to.is_type(exp.DataType.Type.DATE):
1038        return cast_as_date(value)
1039    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
1040        return cast_as_datetime(value)
1041    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
1044def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
1045    if isinstance(cast, exp.Cast):
1046        to = cast.to
1047    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
1048        to = exp.DataType.build(exp.DataType.Type.DATE)
1049    else:
1050        return None
1051
1052    if isinstance(cast.this, exp.Literal):
1053        value: t.Any = cast.this.name
1054    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
1055        value = extract_date(cast.this)
1056    else:
1057        return None
1058    return cast_value(value, to)
def extract_interval(expression):
1065def extract_interval(expression):
1066    try:
1067        n = int(expression.name)
1068        unit = expression.text("unit").lower()
1069        return interval(unit, n)
1070    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
1071        return None
def date_literal(date):
1074def date_literal(date):
1075    return exp.cast(
1076        exp.Literal.string(date),
1077        (
1078            exp.DataType.Type.DATETIME
1079            if isinstance(date, datetime.datetime)
1080            else exp.DataType.Type.DATE
1081        ),
1082    )
def interval(unit: str, n: int = 1):
1085def interval(unit: str, n: int = 1):
1086    from dateutil.relativedelta import relativedelta
1087
1088    if unit == "year":
1089        return relativedelta(years=1 * n)
1090    if unit == "quarter":
1091        return relativedelta(months=3 * n)
1092    if unit == "month":
1093        return relativedelta(months=1 * n)
1094    if unit == "week":
1095        return relativedelta(weeks=1 * n)
1096    if unit == "day":
1097        return relativedelta(days=1 * n)
1098    if unit == "hour":
1099        return relativedelta(hours=1 * n)
1100    if unit == "minute":
1101        return relativedelta(minutes=1 * n)
1102    if unit == "second":
1103        return relativedelta(seconds=1 * n)
1104
1105    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
1108def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
1109    if unit == "year":
1110        return d.replace(month=1, day=1)
1111    if unit == "quarter":
1112        if d.month <= 3:
1113            return d.replace(month=1, day=1)
1114        elif d.month <= 6:
1115            return d.replace(month=4, day=1)
1116        elif d.month <= 9:
1117            return d.replace(month=7, day=1)
1118        else:
1119            return d.replace(month=10, day=1)
1120    if unit == "month":
1121        return d.replace(month=d.month, day=1)
1122    if unit == "week":
1123        # Assuming week starts on Monday (0) and ends on Sunday (6)
1124        return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
1125    if unit == "day":
1126        return d
1127
1128    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
1131def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
1132    floor = date_floor(d, unit, dialect)
1133
1134    if floor == d:
1135        return d
1136
1137    return floor + interval(unit)
def boolean_literal(condition):
1140def boolean_literal(condition):
1141    return exp.true() if condition else exp.false()
def gen(expression: Any) -> str:
1170def gen(expression: t.Any) -> str:
1171    """Simple pseudo sql generator for quickly generating sortable and uniq strings.
1172
1173    Sorting and deduping sql is a necessary step for optimization. Calling the actual
1174    generator is expensive so we have a bare minimum sql generator here.
1175    """
1176    return Gen().gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
1179class Gen:
1180    def __init__(self):
1181        self.stack = []
1182        self.sqls = []
1183
1184    def gen(self, expression: exp.Expression) -> str:
1185        self.stack = [expression]
1186        self.sqls.clear()
1187
1188        while self.stack:
1189            node = self.stack.pop()
1190
1191            if isinstance(node, exp.Expression):
1192                exp_handler_name = f"{node.key}_sql"
1193
1194                if hasattr(self, exp_handler_name):
1195                    getattr(self, exp_handler_name)(node)
1196                elif isinstance(node, exp.Func):
1197                    self._function(node)
1198                else:
1199                    key = node.key.upper()
1200                    self.stack.append(f"{key} " if self._args(node) else key)
1201            elif type(node) is list:
1202                for n in reversed(node):
1203                    if n is not None:
1204                        self.stack.extend((n, ","))
1205                if node:
1206                    self.stack.pop()
1207            else:
1208                if node is not None:
1209                    self.sqls.append(str(node))
1210
1211        return "".join(self.sqls)
1212
1213    def add_sql(self, e: exp.Add) -> None:
1214        self._binary(e, " + ")
1215
1216    def alias_sql(self, e: exp.Alias) -> None:
1217        self.stack.extend(
1218            (
1219                e.args.get("alias"),
1220                " AS ",
1221                e.args.get("this"),
1222            )
1223        )
1224
1225    def and_sql(self, e: exp.And) -> None:
1226        self._binary(e, " AND ")
1227
1228    def anonymous_sql(self, e: exp.Anonymous) -> None:
1229        this = e.this
1230        if isinstance(this, str):
1231            name = this.upper()
1232        elif isinstance(this, exp.Identifier):
1233            name = this.this
1234            name = f'"{name}"' if this.quoted else name.upper()
1235        else:
1236            raise ValueError(
1237                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
1238            )
1239
1240        self.stack.extend(
1241            (
1242                ")",
1243                e.expressions,
1244                "(",
1245                name,
1246            )
1247        )
1248
1249    def between_sql(self, e: exp.Between) -> None:
1250        self.stack.extend(
1251            (
1252                e.args.get("high"),
1253                " AND ",
1254                e.args.get("low"),
1255                " BETWEEN ",
1256                e.this,
1257            )
1258        )
1259
1260    def boolean_sql(self, e: exp.Boolean) -> None:
1261        self.stack.append("TRUE" if e.this else "FALSE")
1262
1263    def bracket_sql(self, e: exp.Bracket) -> None:
1264        self.stack.extend(
1265            (
1266                "]",
1267                e.expressions,
1268                "[",
1269                e.this,
1270            )
1271        )
1272
1273    def column_sql(self, e: exp.Column) -> None:
1274        for p in reversed(e.parts):
1275            self.stack.extend((p, "."))
1276        self.stack.pop()
1277
1278    def datatype_sql(self, e: exp.DataType) -> None:
1279        self._args(e, 1)
1280        self.stack.append(f"{e.this.name} ")
1281
1282    def div_sql(self, e: exp.Div) -> None:
1283        self._binary(e, " / ")
1284
1285    def dot_sql(self, e: exp.Dot) -> None:
1286        self._binary(e, ".")
1287
1288    def eq_sql(self, e: exp.EQ) -> None:
1289        self._binary(e, " = ")
1290
1291    def from_sql(self, e: exp.From) -> None:
1292        self.stack.extend((e.this, "FROM "))
1293
1294    def gt_sql(self, e: exp.GT) -> None:
1295        self._binary(e, " > ")
1296
1297    def gte_sql(self, e: exp.GTE) -> None:
1298        self._binary(e, " >= ")
1299
1300    def identifier_sql(self, e: exp.Identifier) -> None:
1301        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
1302
1303    def ilike_sql(self, e: exp.ILike) -> None:
1304        self._binary(e, " ILIKE ")
1305
1306    def in_sql(self, e: exp.In) -> None:
1307        self.stack.append(")")
1308        self._args(e, 1)
1309        self.stack.extend(
1310            (
1311                "(",
1312                " IN ",
1313                e.this,
1314            )
1315        )
1316
1317    def intdiv_sql(self, e: exp.IntDiv) -> None:
1318        self._binary(e, " DIV ")
1319
1320    def is_sql(self, e: exp.Is) -> None:
1321        self._binary(e, " IS ")
1322
1323    def like_sql(self, e: exp.Like) -> None:
1324        self._binary(e, " Like ")
1325
1326    def literal_sql(self, e: exp.Literal) -> None:
1327        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
1328
1329    def lt_sql(self, e: exp.LT) -> None:
1330        self._binary(e, " < ")
1331
1332    def lte_sql(self, e: exp.LTE) -> None:
1333        self._binary(e, " <= ")
1334
1335    def mod_sql(self, e: exp.Mod) -> None:
1336        self._binary(e, " % ")
1337
1338    def mul_sql(self, e: exp.Mul) -> None:
1339        self._binary(e, " * ")
1340
1341    def neg_sql(self, e: exp.Neg) -> None:
1342        self._unary(e, "-")
1343
1344    def neq_sql(self, e: exp.NEQ) -> None:
1345        self._binary(e, " <> ")
1346
1347    def not_sql(self, e: exp.Not) -> None:
1348        self._unary(e, "NOT ")
1349
1350    def null_sql(self, e: exp.Null) -> None:
1351        self.stack.append("NULL")
1352
1353    def or_sql(self, e: exp.Or) -> None:
1354        self._binary(e, " OR ")
1355
1356    def paren_sql(self, e: exp.Paren) -> None:
1357        self.stack.extend(
1358            (
1359                ")",
1360                e.this,
1361                "(",
1362            )
1363        )
1364
1365    def sub_sql(self, e: exp.Sub) -> None:
1366        self._binary(e, " - ")
1367
1368    def subquery_sql(self, e: exp.Subquery) -> None:
1369        self._args(e, 2)
1370        alias = e.args.get("alias")
1371        if alias:
1372            self.stack.append(alias)
1373        self.stack.extend((")", e.this, "("))
1374
1375    def table_sql(self, e: exp.Table) -> None:
1376        self._args(e, 4)
1377        alias = e.args.get("alias")
1378        if alias:
1379            self.stack.append(alias)
1380        for p in reversed(e.parts):
1381            self.stack.extend((p, "."))
1382        self.stack.pop()
1383
1384    def tablealias_sql(self, e: exp.TableAlias) -> None:
1385        columns = e.columns
1386
1387        if columns:
1388            self.stack.extend((")", columns, "("))
1389
1390        self.stack.extend((e.this, " AS "))
1391
1392    def var_sql(self, e: exp.Var) -> None:
1393        self.stack.append(e.this)
1394
1395    def _binary(self, e: exp.Binary, op: str) -> None:
1396        self.stack.extend((e.expression, op, e.this))
1397
1398    def _unary(self, e: exp.Unary, op: str) -> None:
1399        self.stack.extend((e.this, op))
1400
1401    def _function(self, e: exp.Func) -> None:
1402        self.stack.extend(
1403            (
1404                ")",
1405                list(e.args.values()),
1406                "(",
1407                e.sql_name(),
1408            )
1409        )
1410
1411    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
1412        kvs = []
1413        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types
1414
1415        for k in arg_types or arg_types:
1416            v = node.args.get(k)
1417
1418            if v is not None:
1419                kvs.append([f":{k}", v])
1420        if kvs:
1421            self.stack.append(kvs)
1422            return True
1423        return False
stack
sqls
def gen(self, expression: sqlglot.expressions.Expression) -> str:
1184    def gen(self, expression: exp.Expression) -> str:
1185        self.stack = [expression]
1186        self.sqls.clear()
1187
1188        while self.stack:
1189            node = self.stack.pop()
1190
1191            if isinstance(node, exp.Expression):
1192                exp_handler_name = f"{node.key}_sql"
1193
1194                if hasattr(self, exp_handler_name):
1195                    getattr(self, exp_handler_name)(node)
1196                elif isinstance(node, exp.Func):
1197                    self._function(node)
1198                else:
1199                    key = node.key.upper()
1200                    self.stack.append(f"{key} " if self._args(node) else key)
1201            elif type(node) is list:
1202                for n in reversed(node):
1203                    if n is not None:
1204                        self.stack.extend((n, ","))
1205                if node:
1206                    self.stack.pop()
1207            else:
1208                if node is not None:
1209                    self.sqls.append(str(node))
1210
1211        return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.Add) -> None:
1213    def add_sql(self, e: exp.Add) -> None:
1214        self._binary(e, " + ")
def alias_sql(self, e: sqlglot.expressions.Alias) -> None:
1216    def alias_sql(self, e: exp.Alias) -> None:
1217        self.stack.extend(
1218            (
1219                e.args.get("alias"),
1220                " AS ",
1221                e.args.get("this"),
1222            )
1223        )
def and_sql(self, e: sqlglot.expressions.And) -> None:
1225    def and_sql(self, e: exp.And) -> None:
1226        self._binary(e, " AND ")
def anonymous_sql(self, e: sqlglot.expressions.Anonymous) -> None:
1228    def anonymous_sql(self, e: exp.Anonymous) -> None:
1229        this = e.this
1230        if isinstance(this, str):
1231            name = this.upper()
1232        elif isinstance(this, exp.Identifier):
1233            name = this.this
1234            name = f'"{name}"' if this.quoted else name.upper()
1235        else:
1236            raise ValueError(
1237                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
1238            )
1239
1240        self.stack.extend(
1241            (
1242                ")",
1243                e.expressions,
1244                "(",
1245                name,
1246            )
1247        )
def between_sql(self, e: sqlglot.expressions.Between) -> None:
1249    def between_sql(self, e: exp.Between) -> None:
1250        self.stack.extend(
1251            (
1252                e.args.get("high"),
1253                " AND ",
1254                e.args.get("low"),
1255                " BETWEEN ",
1256                e.this,
1257            )
1258        )
def boolean_sql(self, e: sqlglot.expressions.Boolean) -> None:
1260    def boolean_sql(self, e: exp.Boolean) -> None:
1261        self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: sqlglot.expressions.Bracket) -> None:
1263    def bracket_sql(self, e: exp.Bracket) -> None:
1264        self.stack.extend(
1265            (
1266                "]",
1267                e.expressions,
1268                "[",
1269                e.this,
1270            )
1271        )
def column_sql(self, e: sqlglot.expressions.Column) -> None:
1273    def column_sql(self, e: exp.Column) -> None:
1274        for p in reversed(e.parts):
1275            self.stack.extend((p, "."))
1276        self.stack.pop()
def datatype_sql(self, e: sqlglot.expressions.DataType) -> None:
1278    def datatype_sql(self, e: exp.DataType) -> None:
1279        self._args(e, 1)
1280        self.stack.append(f"{e.this.name} ")
def div_sql(self, e: sqlglot.expressions.Div) -> None:
1282    def div_sql(self, e: exp.Div) -> None:
1283        self._binary(e, " / ")
def dot_sql(self, e: sqlglot.expressions.Dot) -> None:
1285    def dot_sql(self, e: exp.Dot) -> None:
1286        self._binary(e, ".")
def eq_sql(self, e: sqlglot.expressions.EQ) -> None:
1288    def eq_sql(self, e: exp.EQ) -> None:
1289        self._binary(e, " = ")
def from_sql(self, e: sqlglot.expressions.From) -> None:
1291    def from_sql(self, e: exp.From) -> None:
1292        self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: sqlglot.expressions.GT) -> None:
1294    def gt_sql(self, e: exp.GT) -> None:
1295        self._binary(e, " > ")
def gte_sql(self, e: sqlglot.expressions.GTE) -> None:
1297    def gte_sql(self, e: exp.GTE) -> None:
1298        self._binary(e, " >= ")
def identifier_sql(self, e: sqlglot.expressions.Identifier) -> None:
1300    def identifier_sql(self, e: exp.Identifier) -> None:
1301        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: sqlglot.expressions.ILike) -> None:
1303    def ilike_sql(self, e: exp.ILike) -> None:
1304        self._binary(e, " ILIKE ")
def in_sql(self, e: sqlglot.expressions.In) -> None:
1306    def in_sql(self, e: exp.In) -> None:
1307        self.stack.append(")")
1308        self._args(e, 1)
1309        self.stack.extend(
1310            (
1311                "(",
1312                " IN ",
1313                e.this,
1314            )
1315        )
def intdiv_sql(self, e: sqlglot.expressions.IntDiv) -> None:
1317    def intdiv_sql(self, e: exp.IntDiv) -> None:
1318        self._binary(e, " DIV ")
def is_sql(self, e: sqlglot.expressions.Is) -> None:
1320    def is_sql(self, e: exp.Is) -> None:
1321        self._binary(e, " IS ")
def like_sql(self, e: sqlglot.expressions.Like) -> None:
1323    def like_sql(self, e: exp.Like) -> None:
1324        self._binary(e, " Like ")
def literal_sql(self, e: sqlglot.expressions.Literal) -> None:
1326    def literal_sql(self, e: exp.Literal) -> None:
1327        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: sqlglot.expressions.LT) -> None:
1329    def lt_sql(self, e: exp.LT) -> None:
1330        self._binary(e, " < ")
def lte_sql(self, e: sqlglot.expressions.LTE) -> None:
1332    def lte_sql(self, e: exp.LTE) -> None:
1333        self._binary(e, " <= ")
def mod_sql(self, e: sqlglot.expressions.Mod) -> None:
1335    def mod_sql(self, e: exp.Mod) -> None:
1336        self._binary(e, " % ")
def mul_sql(self, e: sqlglot.expressions.Mul) -> None:
1338    def mul_sql(self, e: exp.Mul) -> None:
1339        self._binary(e, " * ")
def neg_sql(self, e: sqlglot.expressions.Neg) -> None:
1341    def neg_sql(self, e: exp.Neg) -> None:
1342        self._unary(e, "-")
def neq_sql(self, e: sqlglot.expressions.NEQ) -> None:
1344    def neq_sql(self, e: exp.NEQ) -> None:
1345        self._binary(e, " <> ")
def not_sql(self, e: sqlglot.expressions.Not) -> None:
1347    def not_sql(self, e: exp.Not) -> None:
1348        self._unary(e, "NOT ")
def null_sql(self, e: sqlglot.expressions.Null) -> None:
1350    def null_sql(self, e: exp.Null) -> None:
1351        self.stack.append("NULL")
def or_sql(self, e: sqlglot.expressions.Or) -> None:
1353    def or_sql(self, e: exp.Or) -> None:
1354        self._binary(e, " OR ")
def paren_sql(self, e: sqlglot.expressions.Paren) -> None:
1356    def paren_sql(self, e: exp.Paren) -> None:
1357        self.stack.extend(
1358            (
1359                ")",
1360                e.this,
1361                "(",
1362            )
1363        )
def sub_sql(self, e: sqlglot.expressions.Sub) -> None:
1365    def sub_sql(self, e: exp.Sub) -> None:
1366        self._binary(e, " - ")
def subquery_sql(self, e: sqlglot.expressions.Subquery) -> None:
1368    def subquery_sql(self, e: exp.Subquery) -> None:
1369        self._args(e, 2)
1370        alias = e.args.get("alias")
1371        if alias:
1372            self.stack.append(alias)
1373        self.stack.extend((")", e.this, "("))
def table_sql(self, e: sqlglot.expressions.Table) -> None:
1375    def table_sql(self, e: exp.Table) -> None:
1376        self._args(e, 4)
1377        alias = e.args.get("alias")
1378        if alias:
1379            self.stack.append(alias)
1380        for p in reversed(e.parts):
1381            self.stack.extend((p, "."))
1382        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.TableAlias) -> None:
1384    def tablealias_sql(self, e: exp.TableAlias) -> None:
1385        columns = e.columns
1386
1387        if columns:
1388            self.stack.extend((")", columns, "("))
1389
1390        self.stack.extend((e.this, " AS "))
def var_sql(self, e: sqlglot.expressions.Var) -> None:
1392    def var_sql(self, e: exp.Var) -> None:
1393        self.stack.append(e.this)