sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import functools 5import itertools 6import typing as t 7from collections import deque 8from decimal import Decimal 9 10import sqlglot 11from sqlglot import Dialect, exp 12from sqlglot.helper import first, is_iterable, merge_ranges, while_changing 13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 14 15if t.TYPE_CHECKING: 16 from sqlglot.dialects.dialect import DialectType 17 18 DateTruncBinaryTransform = t.Callable[ 19 [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] 20 ] 21 22# Final means that an expression should not be simplified 23FINAL = "final" 24 25 26class UnsupportedUnit(Exception): 27 pass 28 29 30def simplify( 31 expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None 32): 33 """ 34 Rewrite sqlglot AST to simplify expressions. 35 36 Example: 37 >>> import sqlglot 38 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 39 >>> simplify(expression).sql() 40 'TRUE' 41 42 Args: 43 expression (sqlglot.Expression): expression to simplify 44 constant_propagation: whether or not the constant propagation rule should be used 45 46 Returns: 47 sqlglot.Expression: simplified expression 48 """ 49 50 dialect = Dialect.get_or_raise(dialect) 51 52 def _simplify(expression, root=True): 53 if expression.meta.get(FINAL): 54 return expression 55 56 # group by expressions cannot be simplified, for example 57 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 58 # the projection must exactly match the group by key 59 group = expression.args.get("group") 60 61 if group and hasattr(expression, "selects"): 62 groups = set(group.expressions) 63 group.meta[FINAL] = True 64 65 for e in expression.selects: 66 for node, *_ in e.walk(): 67 if node in groups: 68 e.meta[FINAL] = True 69 break 70 71 having = expression.args.get("having") 72 if having: 73 for node, *_ in having.walk(): 74 if node in groups: 75 having.meta[FINAL] = True 76 break 77 78 # 
Pre-order transformations 79 node = expression 80 node = rewrite_between(node) 81 node = uniq_sort(node, root) 82 node = absorb_and_eliminate(node, root) 83 node = simplify_concat(node) 84 node = simplify_conditionals(node) 85 86 if constant_propagation: 87 node = propagate_constants(node, root) 88 89 exp.replace_children(node, lambda e: _simplify(e, False)) 90 91 # Post-order transformations 92 node = simplify_not(node) 93 node = flatten(node) 94 node = simplify_connectors(node, root) 95 node = remove_complements(node, root) 96 node = simplify_coalesce(node) 97 node.parent = expression.parent 98 node = simplify_literals(node, root) 99 node = simplify_equality(node) 100 node = simplify_parens(node) 101 node = simplify_datetrunc(node, dialect) 102 node = sort_comparison(node) 103 104 if root: 105 expression.replace(node) 106 107 return node 108 109 expression = while_changing(expression, _simplify) 110 remove_where_true(expression) 111 return expression 112 113 114def catch(*exceptions): 115 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 116 117 def decorator(func): 118 def wrapped(expression, *args, **kwargs): 119 try: 120 return func(expression, *args, **kwargs) 121 except exceptions: 122 return expression 123 124 return wrapped 125 126 return decorator 127 128 129def rewrite_between(expression: exp.Expression) -> exp.Expression: 130 """Rewrite x between y and z to x >= y AND x <= z. 131 132 This is done because comparison simplification is only done on lt/lte/gt/gte. 
133 """ 134 if isinstance(expression, exp.Between): 135 negate = isinstance(expression.parent, exp.Not) 136 137 expression = exp.and_( 138 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 139 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 140 copy=False, 141 ) 142 143 if negate: 144 expression = exp.paren(expression, copy=False) 145 146 return expression 147 148 149COMPLEMENT_COMPARISONS = { 150 exp.LT: exp.GTE, 151 exp.GT: exp.LTE, 152 exp.LTE: exp.GT, 153 exp.GTE: exp.LT, 154 exp.EQ: exp.NEQ, 155 exp.NEQ: exp.EQ, 156} 157 158 159def simplify_not(expression): 160 """ 161 Demorgan's Law 162 NOT (x OR y) -> NOT x AND NOT y 163 NOT (x AND y) -> NOT x OR NOT y 164 """ 165 if isinstance(expression, exp.Not): 166 this = expression.this 167 if is_null(this): 168 return exp.null() 169 if this.__class__ in COMPLEMENT_COMPARISONS: 170 return COMPLEMENT_COMPARISONS[this.__class__]( 171 this=this.this, expression=this.expression 172 ) 173 if isinstance(this, exp.Paren): 174 condition = this.unnest() 175 if isinstance(condition, exp.And): 176 return exp.or_( 177 exp.not_(condition.left, copy=False), 178 exp.not_(condition.right, copy=False), 179 copy=False, 180 ) 181 if isinstance(condition, exp.Or): 182 return exp.and_( 183 exp.not_(condition.left, copy=False), 184 exp.not_(condition.right, copy=False), 185 copy=False, 186 ) 187 if is_null(condition): 188 return exp.null() 189 if always_true(this): 190 return exp.false() 191 if is_false(this): 192 return exp.true() 193 if isinstance(this, exp.Not): 194 # double negation 195 # NOT NOT x -> x 196 return this.this 197 return expression 198 199 200def flatten(expression): 201 """ 202 A AND (B AND C) -> A AND B AND C 203 A OR (B OR C) -> A OR B OR C 204 """ 205 if isinstance(expression, exp.Connector): 206 for node in expression.args.values(): 207 child = node.unnest() 208 if isinstance(child, expression.__class__): 209 node.replace(child) 210 return expression 211 212 213def 
simplify_connectors(expression, root=True): 214 def _simplify_connectors(expression, left, right): 215 if left == right: 216 return left 217 if isinstance(expression, exp.And): 218 if is_false(left) or is_false(right): 219 return exp.false() 220 if is_null(left) or is_null(right): 221 return exp.null() 222 if always_true(left) and always_true(right): 223 return exp.true() 224 if always_true(left): 225 return right 226 if always_true(right): 227 return left 228 return _simplify_comparison(expression, left, right) 229 elif isinstance(expression, exp.Or): 230 if always_true(left) or always_true(right): 231 return exp.true() 232 if is_false(left) and is_false(right): 233 return exp.false() 234 if ( 235 (is_null(left) and is_null(right)) 236 or (is_null(left) and is_false(right)) 237 or (is_false(left) and is_null(right)) 238 ): 239 return exp.null() 240 if is_false(left): 241 return right 242 if is_false(right): 243 return left 244 return _simplify_comparison(expression, left, right, or_=True) 245 246 if isinstance(expression, exp.Connector): 247 return _flat_simplify(expression, _simplify_connectors, root) 248 return expression 249 250 251LT_LTE = (exp.LT, exp.LTE) 252GT_GTE = (exp.GT, exp.GTE) 253 254COMPARISONS = ( 255 *LT_LTE, 256 *GT_GTE, 257 exp.EQ, 258 exp.NEQ, 259 exp.Is, 260) 261 262INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 263 exp.LT: exp.GT, 264 exp.GT: exp.LT, 265 exp.LTE: exp.GTE, 266 exp.GTE: exp.LTE, 267} 268 269NONDETERMINISTIC = (exp.Rand, exp.Randn) 270 271 272def _simplify_comparison(expression, left, right, or_=False): 273 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 274 ll, lr = left.args.values() 275 rl, rr = right.args.values() 276 277 largs = {ll, lr} 278 rargs = {rl, rr} 279 280 matching = largs & rargs 281 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 282 283 if matching and columns: 284 try: 285 l = first(largs - columns) 286 r = 
first(rargs - columns) 287 except StopIteration: 288 return expression 289 290 if l.is_number and r.is_number: 291 l = float(l.name) 292 r = float(r.name) 293 elif l.is_string and r.is_string: 294 l = l.name 295 r = r.name 296 else: 297 l = extract_date(l) 298 if not l: 299 return None 300 r = extract_date(r) 301 if not r: 302 return None 303 304 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 305 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 306 return left if (av > bv if or_ else av <= bv) else right 307 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 308 return left if (av < bv if or_ else av >= bv) else right 309 310 # we can't ever shortcut to true because the column could be null 311 if not or_: 312 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 313 if av <= bv: 314 return exp.false() 315 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 316 if av >= bv: 317 return exp.false() 318 elif isinstance(a, exp.EQ): 319 if isinstance(b, exp.LT): 320 return exp.false() if av >= bv else a 321 if isinstance(b, exp.LTE): 322 return exp.false() if av > bv else a 323 if isinstance(b, exp.GT): 324 return exp.false() if av <= bv else a 325 if isinstance(b, exp.GTE): 326 return exp.false() if av < bv else a 327 if isinstance(b, exp.NEQ): 328 return exp.false() if av == bv else a 329 return None 330 331 332def remove_complements(expression, root=True): 333 """ 334 Removing complements. 335 336 A AND NOT A -> FALSE 337 A OR NOT A -> TRUE 338 """ 339 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 340 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 341 342 for a, b in itertools.permutations(expression.flatten(), 2): 343 if is_complement(a, b): 344 return complement 345 return expression 346 347 348def uniq_sort(expression, root=True): 349 """ 350 Uniq and sort a connector. 
351 352 C AND A AND B AND B -> A AND B AND C 353 """ 354 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 355 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 356 flattened = tuple(expression.flatten()) 357 deduped = {gen(e): e for e in flattened} 358 arr = tuple(deduped.items()) 359 360 # check if the operands are already sorted, if not sort them 361 # A AND C AND B -> A AND B AND C 362 for i, (sql, e) in enumerate(arr[1:]): 363 if sql < arr[i][0]: 364 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 365 break 366 else: 367 # we didn't have to sort but maybe we need to dedup 368 if len(deduped) < len(flattened): 369 expression = result_func(*deduped.values(), copy=False) 370 371 return expression 372 373 374def absorb_and_eliminate(expression, root=True): 375 """ 376 absorption: 377 A AND (A OR B) -> A 378 A OR (A AND B) -> A 379 A AND (NOT A OR B) -> A AND B 380 A OR (NOT A AND B) -> A OR B 381 elimination: 382 (A AND B) OR (A AND NOT B) -> A 383 (A OR B) AND (A OR NOT B) -> A 384 """ 385 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 386 kind = exp.Or if isinstance(expression, exp.And) else exp.And 387 388 for a, b in itertools.permutations(expression.flatten(), 2): 389 if isinstance(a, kind): 390 aa, ab = a.unnest_operands() 391 392 # absorb 393 if is_complement(b, aa): 394 aa.replace(exp.true() if kind == exp.And else exp.false()) 395 elif is_complement(b, ab): 396 ab.replace(exp.true() if kind == exp.And else exp.false()) 397 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 398 a.replace(exp.false() if kind == exp.And else exp.true()) 399 elif isinstance(b, kind): 400 # eliminate 401 rhs = b.unnest_operands() 402 ba, bb = rhs 403 404 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 405 a.replace(aa) 406 b.replace(aa) 407 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 408 a.replace(ab) 
409 b.replace(ab) 410 411 return expression 412 413 414def propagate_constants(expression, root=True): 415 """ 416 Propagate constants for conjunctions in DNF: 417 418 SELECT * FROM t WHERE a = b AND b = 5 becomes 419 SELECT * FROM t WHERE a = 5 AND b = 5 420 421 Reference: https://www.sqlite.org/optoverview.html 422 """ 423 424 if ( 425 isinstance(expression, exp.And) 426 and (root or not expression.same_parent) 427 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 428 ): 429 constant_mapping = {} 430 for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)): 431 if isinstance(expr, exp.EQ): 432 l, r = expr.left, expr.right 433 434 # TODO: create a helper that can be used to detect nested literal expressions such 435 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 436 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 437 constant_mapping[l] = (id(l), r) 438 439 if constant_mapping: 440 for column in find_all_in_scope(expression, exp.Column): 441 parent = column.parent 442 column_id, constant = constant_mapping.get(column) or (None, None) 443 if ( 444 column_id is not None 445 and id(column) != column_id 446 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 447 ): 448 column.replace(constant.copy()) 449 450 return expression 451 452 453INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 454 exp.DateAdd: exp.Sub, 455 exp.DateSub: exp.Add, 456 exp.DatetimeAdd: exp.Sub, 457 exp.DatetimeSub: exp.Add, 458} 459 460INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 461 **INVERSE_DATE_OPS, 462 exp.Add: exp.Sub, 463 exp.Sub: exp.Add, 464} 465 466 467def _is_number(expression: exp.Expression) -> bool: 468 return expression.is_number 469 470 471def _is_interval(expression: exp.Expression) -> bool: 472 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 473 474 
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    The operands of `l` are only swapped when `l` is an addition, because only addition
    is commutative: 1 + x = 3 may become x = 3 - 1, but 1 - x = 3 must not become
    x = 3 + 1.
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ not in INVERSE_OPS:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """
    Fold constant subexpressions.

    Non-connector binary chains are folded pairwise with `_simplify_binary`, a negated
    number literal becomes a signed literal, and DATE_ADD/DATE_SUB-style operations
    (`INVERSE_DATE_OPS`) on a date literal and an interval are evaluated.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression


def _simplify_binary(expression, a, b):
    """
    Evaluate the binary operation `expression` applied to the operands `a` and `b`.

    Returns the folded expression, or None when the operation cannot be evaluated
    statically.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            # Division by zero is an engine-specific runtime error (or NULL); leave it to the
            # engine instead of letting Decimal raise here
            if num_b == 0:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None


def simplify_parens(expression):
    """Remove a parenthesis when the wrapped node and its parent make it redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return whether `expression` is a literal, a boolean or a date literal."""
    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    """Return whether `expression` is a literal, a boolean, NULL or a date literal."""
    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    COALESCE(x) -> x, and COALESCE(c, ...) -> c when c is a non-NULL constant.
    A comparison of a COALESCE with a constant is expanded into the two cases
    "first argument is not NULL" / "first argument is NULL", e.g. COALESCE(x, 1) = 2 ->
    (NOT x IS NULL AND COALESCE(x) = 2) OR (x IS NULL AND 1 = 2); later passes simplify
    the result further.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs)
        and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit, dialect)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, left) = date -> left >= date AND left < date + 1 unit."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, left) <> date -> left < date OR left >= date + 1 unit.

    `left` can't be both below and above the range, so the two bounds must be
    combined with OR (combining them with AND would always be false).
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return exp.or_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )


DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return whether `left` is a DATE/TIMESTAMP_TRUNC call and `right` a date literal."""
    return isinstance(left, DATETRUNCS) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        date = extract_date(expression.this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Normalize the operand order of a comparison.

    Columns (or else non-constants) go on the left and constants on the right; otherwise
    the operands are ordered by their `gen` string. When the operands are swapped, the
    operator is mirrored (x > 1 <-> 1 < x).
    """
    if expression.__class__ in COMPLEMENT_COMPARISONS:
        l, r = expression.this, expression.expression
        l_column = isinstance(l, exp.Column)
        r_column = isinstance(r, exp.Column)
        l_const = _is_constant(l)
        r_const = _is_constant(r)

        if (l_column and not r_column) or (r_const and not l_const):
            return expression
        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                this=r, expression=l
            )
    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """
    Drop WHERE clauses whose condition is always true, and turn joins listed in `JOINS`
    whose ON condition is always true (and that have no USING / method) into CROSS joins.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value if `expression` is TRUE or a literal other than numeric zero.

    A zero literal is excluded because it is falsy when used as a condition
    (e.g. `WHERE 0` must not be removed as if it were `WHERE TRUE`).
    """
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not _is_zero(expression)
    )


def always_false(expression):
    """Return whether `expression` is FALSE or NULL."""
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is the boolean literal FALSE."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is the NULL literal."""
    return type(a) is exp.Null


def _is_zero(expression: exp.Expression) -> bool:
    """Return whether `expression` is a number literal whose value is zero."""
    if not expression.is_number:
        return False
    try:
        return Decimal(expression.name) == 0
    except ArithmeticError:
        # decimal.InvalidOperation: a number literal Decimal can't parse (e.g. hex)
        return False


def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns TRUE/FALSE literals, or None if `expression` is not a supported comparison.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a date, a datetime or an ISO-format string to a date (None if unparsable)."""
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a date, a datetime or an ISO-format string to a datetime (None if unparsable)."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast `value` to a date (DATE) or datetime (other temporal types); None otherwise."""
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Return the Python date/datetime of a CAST (or format-less TS_OR_DS_TO_DATE) of a literal,
    possibly nested in further casts; None if `cast` is not such an expression.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return whether `extract_date` can evaluate `expression`."""
    return extract_date(expression) is not None


def extract_interval(expression):
    """Return the relativedelta of an INTERVAL with an integer amount, else None."""
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Build CAST('<date>' AS DATE), or AS DATETIME when `date` is a datetime."""
    return exp.cast(
        exp.Literal.string(date),
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE,
    )


def interval(unit: str, n: int = 1):
    """
    Return a `dateutil.relativedelta` of `n` `unit`s.

    Raises:
        UnsupportedUnit: for units other than year/quarter/month/week/day/hour/minute/second.
        ModuleNotFoundError: if dateutil is not installed (it is imported lazily).
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate `d` to the start of its `unit` (year, quarter, month, week or day).

    For datetimes, the time of day is reset to midnight, since every supported unit
    is a day or longer.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if isinstance(d, datetime.datetime):
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # The week starts on the day whose weekday() (Monday == 0) is congruent to
        # WEEK_OFFSET modulo 7 (e.g. -1 is Sunday); the modulo keeps the result within
        # the 7 days on or before `d` for any offset.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Return `d` if it is already aligned to `unit`, else the start of the next `unit`."""
    floor = date_floor(d, unit, dialect)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return the TRUE or FALSE literal for a Python truth value."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """
    Apply `simplifier(expression, a, b)` pairwise to the operands of the flattened chain of
    `expression`'s class.

    Whenever `simplifier` returns a new node, it replaces both operands and is retried
    against the remaining ones. If any operands were merged, the chain is rebuilt
    left-associatively; otherwise `expression` is returned unchanged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(gen(e) for e in expression)
    if not isinstance(expression, exp.Expression):
        return str(expression)

    etype = type(expression)
    if etype in GEN_MAP:
        return GEN_MAP[etype](expression)
    return f"{expression.key} {gen(expression.args.values())}"


GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(arg) for arg in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Div: lambda e: _binary(e, "/"),
    exp.Dot: lambda e: _binary(e, "."),
    exp.EQ: lambda e: _binary(e, "="),
    exp.GT: lambda e: _binary(e, ">"),
    exp.GTE: lambda e: _binary(e, ">="),
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.ILike: lambda e: _binary(e, "ILIKE"),
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Is: lambda e: _binary(e, "IS"),
    exp.Like: lambda e: _binary(e, "LIKE"),
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.LT: lambda e: _binary(e, "<"),
    exp.LTE: lambda e: _binary(e, "<="),
    exp.Mod: lambda e: _binary(e, "%"),
    exp.Mul: lambda e: _binary(e, "*"),
    exp.Neg: lambda e: _unary(e, "-"),
    exp.NEQ: lambda e: _binary(e, "<>"),
    exp.Not: lambda e: _unary(e, "NOT"),
    exp.Null: lambda e: "NULL",
    exp.Or: lambda e: _binary(e, "OR"),
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Sub: lambda e: _binary(e, "-"),
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}


def _binary(e: exp.Binary, op: str) -> str:
    """Pseudo-SQL for a binary node: `<left> <op> <right>`."""
    return f"{gen(e.left)} {op} {gen(e.right)}"


def _unary(e: exp.Unary, op: str) -> str:
    """Pseudo-SQL for a unary node: `<op> <operand>`."""
    return f"{op} {gen(e.this)}"
Raised by `interval` and `date_floor` when they are given a date/time unit they do not support (subclass of the built-in `Exception`).
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used
        dialect: dialect (or dialect name) resolved with `Dialect.get_or_raise`; it is
            handed to the DATE_TRUNC simplification rule.

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    # One simplification pass over `expression` and, recursively, its children.
    # `root` is True only for the node the pass was started on; several rules use it to
    # decide whether a connector chain must be processed at this node.
    def _simplify(expression, root=True):
        # Nodes flagged FINAL (see the GROUP BY handling below) are left untouched.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains a GROUP BY key anywhere inside it.
            for e in expression.selects:
                for node, *_ in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for HAVING, which may also reference the grouping keys.
            having = expression.args.get("having")
            if having:
                for node, *_ in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse: every child is simplified as a non-root node.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may return a brand-new node; give it the original parent so the
        # parent-sensitive rules that follow (e.g. simplify_parens inspects the parent)
        # see the real context.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)

        # Only the root splices its result back into the tree; children are swapped in
        # by exp.replace_children in the parent's call.
        if root:
            expression.replace(node)

        return node

    # NOTE(review): while_changing (sqlglot.helper) presumably re-runs _simplify until the
    # tree stops changing — confirm in the helper.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
```
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
```
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether or not the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function is called normally. If it raises one of ``exceptions``, the
    error is swallowed and the first positional argument (``expression``) is returned
    unchanged, so a failing rule leaves the tree as it was.

    Args:
        *exceptions: the exception types to suppress.

    Returns:
        A decorator. Via ``functools.wraps``, the wrapper it produces keeps the decorated
        function's ``__name__``, ``__doc__`` and ``__wrapped__``; this matters for
        documentation tools and debugging.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                # Treat the failure as "no simplification possible".
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand ``x BETWEEN y AND z`` into ``x >= y AND x <= z``.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is lowered into
    those first. If the BETWEEN is the operand of a NOT, the conjunction is parenthesized
    so that the negation still applies to both bounds.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    subject = expression.this

    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(rewritten, copy=False) if under_not else rewritten
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT node.

    De Morgan's laws:
        NOT (x OR y) -> (NOT x AND NOT y)
        NOT (x AND y) -> (NOT x OR NOT y)

    The rewritten connector replaces the NOT node in place. It is therefore wrapped in
    parentheses so it keeps its grouping when that NOT sat inside another connector:
    ``a AND NOT (b AND c)`` must become ``a AND (NOT b OR NOT c)``, not
    ``a AND NOT b OR NOT c``. Where the parentheses are redundant, simplify_parens
    removes them again.

    Other rules:
        NOT NULL -> NULL
        NOT <comparison> -> the complementary comparison (COMPLEMENT_COMPARISONS)
        NOT <always true> -> FALSE
        NOT FALSE -> TRUE
        NOT NOT x -> x
    """
    if isinstance(expression, exp.Not):
        this = expression.this
        if is_null(this):
            return exp.null()
        if this.__class__ in COMPLEMENT_COMPARISONS:
            return COMPLEMENT_COMPARISONS[this.__class__](
                this=this.this, expression=this.expression
            )
        if isinstance(this, exp.Paren):
            condition = this.unnest()
            if isinstance(condition, exp.And):
                return exp.paren(
                    exp.or_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    ),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.paren(
                    exp.and_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    ),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(this):
            return exp.false()
        if is_false(this):
            return exp.true()
        if isinstance(this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return this.this
    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.
def flatten(expression):
    """
    Merge directly nested connectors of the same kind into one chain.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in list(expression.args.values()):
        # Look through any parentheses around the operand.
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            operand.replace(inner)
    return expression
A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold AND/OR connectors whose operands are statically known, following SQL's
    three-valued logic.

    AND:
        x AND x -> x
        FALSE AND x -> FALSE
        NULL AND NULL -> NULL
        <always true> AND x -> x     (so NULL AND TRUE -> NULL)
    OR:
        x OR x -> x
        <always true> OR x -> TRUE
        FALSE OR FALSE -> FALSE
        NULL OR NULL, NULL OR FALSE -> NULL
        FALSE OR x -> x

    ``NULL AND x`` with an unknown ``x`` is deliberately not folded. Its value is FALSE
    when ``x`` is FALSE and NULL otherwise, so rewriting it to NULL would change results,
    e.g. under a NOT or in a projection.
    Anything else is handed to _simplify_comparison.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # Only both-NULL folds to NULL; NULL AND <unknown> may be FALSE.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its complement.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    NOTE(review): under SQL three-valued logic these identities hold only for non-NULL A
    (NULL AND NOT NULL is NULL) — confirm this trade-off is acceptable for callers.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = expression.flatten()
    has_complement = any(
        is_complement(left, right) for left, right in itertools.permutations(operands, 2)
    )
    if not has_complement:
        return expression
    return exp.false() if isinstance(expression, exp.And) else exp.true()
Removing complements.
A AND NOT A -> FALSE; A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector, keyed by their `gen` string.

    C AND A AND B AND B -> A AND B AND C

    The connector is only rebuilt when its operands are out of order or contain duplicates;
    otherwise the original node is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Duplicates (same gen key) keep the first key position but the last operand object.
    unique = {gen(operand): operand for operand in operands}
    keyed = tuple(unique.items())

    out_of_order = any(nxt[0] < prev[0] for prev, nxt in zip(keyed, keyed[1:]))
    if out_of_order:
        return build(*(node for _, node in sorted(keyed)), copy=False)
    if len(unique) < len(operands):
        return build(*unique.values(), copy=False)
    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector chain (in place).

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten with neutral TRUE/FALSE literals or duplicated terms; the
    later connector/dedup rules in the simplify pipeline clean those up.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The "inner" connector type: OR operands inside an AND chain, AND operands
        # inside an OR chain.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # itertools.permutations materializes expression.flatten() before iterating, so
        # the replacements below do not change which (a, b) pairs are visited.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # b is the complement of one of a's operands: that operand can never
                # decide a, so it is replaced by the neutral element of `kind`
                # (TRUE inside an AND, FALSE inside an OR).
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands (or b itself) are a strict subset of a's operands: a is
                # absorbed, i.e. replaced by the neutral element of the outer connector
                # (FALSE in an OR chain, TRUE in an AND chain).
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    # a and b share one operand and their other operands are
                    # complementary: both collapse to the shared operand.
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    if not (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        return expression

    # column -> (id of the column node in its defining `column = literal`, that literal)
    bindings = {}
    for node, *_ in walk_in_scope(expression, prune=lambda n, *_: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue
        left, right = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(left, exp.Column) and isinstance(right, exp.Literal):
            bindings[left] = (id(left), right)

    if not bindings:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        binding = bindings.get(column)
        if binding is None:
            continue
        origin_id, literal = binding
        parent = column.parent
        # Leave `col IS NULL` alone, and never rewrite the defining occurrence itself.
        null_check = isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)
        if id(column) != origin_id and not null_check:
            column.replace(literal.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Wrapper produced by the `catch` decorator: `func` (the decorated rule) and
    # `exceptions` (the types given to `catch(...)`) are closure variables of `catch`.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # One of the listed errors means "no simplification": return the input unchanged.
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
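There are two binary operations in the above expression: `+` and `=`. Here's how the operands are referenced in the code below: in `x + 1 = 3`, `l` is the left operand of `=` (`x + 1`) and `r` is its right operand (`3`); `a` and `b` are the operands of `+` (`x` and `1`).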
def simplify_literals(expression, root=True):
    """Simplify literal-bearing nodes.

    - Non-connector binary nodes: each operand pair is passed to `_simplify_binary` through
      `_flat_simplify`.
    - Negated number literals: the sign is folded into the literal (``-5`` <-> ``5``).
    - Node types listed in INVERSE_DATE_OPS: `_simplify_binary` is tried on the node's
      operand and its interval; the node is kept if that yields nothing.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Remove a Paren wrapper when it cannot affect how the expression is read.

    Parentheses around a SELECT are always kept. Otherwise they are dropped when:
    the parent is neither a Condition nor a Binary node, or is itself a Paren; the
    wrapped node is not a Binary; a predicate is wrapped under a non-predicate parent;
    or the wrapped and parent operators associate (``+`` in ``+``, ``*`` in ``*``,
    ``*`` in ``+``/``-``).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if removable else expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons against them.

    - ``COALESCE(x)`` (no extra args) or ``COALESCE(<non-null constant>, ...)`` becomes its
      first argument, unless the COALESCE is a direct child of a Hint.
    - ``COALESCE(x, ..., c, ...) <cmp> k``, where ``k`` is a constant and ``c`` is the
      first constant argument, is expanded to::

          (COALESCE(x, ...) IS NOT NULL AND COALESCE(x, ...) <cmp> k)
          OR (COALESCE(x, ...) IS NULL AND c <cmp> k)

      The COALESCE keeps only the arguments before ``c``. The fallback comparison keeps
      the original operand order, so ``k <cmp> COALESCE(...)`` falls back to ``k <cmp> c``.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Truncate in place: `expression` (copied below) now holds the shortened COALESCE.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # The fallback must compare the constant on the side the COALESCE occupied; otherwise
    # asymmetric comparisons flip, e.g. `5 < COALESCE(x, 10)` would fall back to `10 < 5`.
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Runs of adjacent string literals are merged into one literal (joined with the
    CONCAT_WS separator when there is one). If the result is a single string literal, that
    literal replaces the whole call. CONCAT_WS is only handled when its separator is a
    string literal.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)
    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        separator_node, *operands = expression.expressions
        separator = separator_node.name
        result_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        separator = ""
        result_type = exp.Concat
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            merged.append(exp.Literal.string(separator.join(part.name for part in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        merged.insert(0, separator_node)

    return result_type(expressions=merged, **extra_args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE:
      * ``CASE x WHEN v ...`` branches are rewritten to ``CASE WHEN x = v ...``.
      * Branches whose condition is always false are dropped; if none remain, the result
        is the ELSE value (or NULL).
      * An always-true branch ends the CASE. If every earlier branch was dropped, its
        value is the result. Otherwise an earlier, undecided branch may still match, so
        the always-true branch becomes the ELSE, and every later branch is removed
        because none of them can be reached.

    IF (not nested directly in a CASE): replaced by its TRUE/FALSE value when the
    condition is known.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Branches seen so far whose condition is not statically known, in order.
        undecided = []

        # Iterate over a snapshot: case.pop() removes the branch from
        # expression.args["ifs"], and removing from a list while iterating over it
        # would skip the branch right after each removed one.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undecided:
                    return case.args["true"]
                # Can't return this branch outright (an earlier branch may match), but
                # nothing after it is reachable.
                expression.set("ifs", undecided)
                expression.set("default", case.args["true"])
                return expression

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided.append(case)
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def wrapped(expression, *args, **kwargs):
    # Wrapper produced by the `catch` decorator: `func` (the decorated rule) and
    # `exceptions` (the types given to `catch(...)`) are closure variables of `catch`.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # One of the listed errors means "no simplification": return the input unchanged.
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    A column is preferred on the left and a constant on the right. Otherwise the operands
    are ordered by their `gen` strings. A swap uses the mirrored operator from
    INVERSE_COMPARISONS, or the same class when it has no entry there. Non-comparisons are
    returned unchanged.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_constant = _is_constant(left)
    right_is_constant = _is_constant(right)

    # Already canonical.
    if (left_is_column and not right_is_column) or (right_is_constant and not left_is_constant):
        return expression

    should_swap = (
        (right_is_column and not left_is_column)
        or (left_is_constant and not right_is_constant)
        or gen(left) > gen(right)
    )
    if not should_swap:
        return expression

    comparison_type = expression.__class__
    mirrored = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
    return mirrored(this=right, expression=left)
def remove_where_true(expression):
    """Strip conditions that are always true, in place.

    A WHERE whose condition is always true is removed from its parent. A join whose ON
    condition is always true becomes a CROSS join (ON and side cleared), provided it has
    no USING, no join method, and a (side, kind) pair listed in JOINS.
    """
    for where_node in expression.find_all(exp.Where):
        if not always_true(where_node.this):
            continue
        where_node.parent.set("where", None)

    for join_node in expression.find_all(exp.Join):
        becomes_cross = (
            always_true(join_node.args.get("on"))
            and not join_node.args.get("using")
            and not join_node.args.get("method")
            and (join_node.side, join_node.kind) in JOINS
        )
        if not becomes_cross:
            continue
        join_node.set("on", None)
        join_node.set("side", None)
        join_node.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison type of ``expression`` on the Python values ``a`` and ``b``.

    EQ and IS compare with ``==``; NEQ, GT, GTE, LT and LTE use the matching Python operator.
    The outcome is returned as a TRUE/FALSE literal from `boolean_literal`. For any other
    node type the result is None.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce ``value`` to a ``datetime.datetime``.

    A datetime is returned as-is. A plain date becomes midnight of that day. Anything else
    is parsed with ``datetime.fromisoformat``, and an unparseable string gives None.
    """
    # datetime is a subclass of date, so it must be checked first.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime.combine(value, datetime.time())
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate casting ``value`` to the temporal type ``to``.

    Args:
        value: the value to convert, e.g. a literal's text or a date/datetime already
            produced by an inner cast.
        to: the target data type.

    Returns:
        a ``datetime.date`` for DATE (via cast_as_date), a ``datetime.datetime`` for other
        temporal types (via cast_as_datetime), or None when ``value`` is falsy, ``to`` is
        not temporal, or the conversion fails.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a cast of a literal to a date/datetime.

    Handled shapes are ``CAST(<literal> AS <type>)`` and ``TsOrDsToDate(<literal>)``
    without a format argument; the latter is evaluated as a cast to DATE. The inner
    value may itself be such a cast, in which case it is evaluated first (recursively).

    Returns:
        the resulting ``datetime.date``/``datetime.datetime``, or None when ``cast`` has
        another shape or its value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Build a ``dateutil`` relativedelta spanning ``n`` of the given ``unit``.

    Supported units: year, quarter (three months), month, week, day, hour, minute and
    second. Any other unit raises UnsupportedUnit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier applied to n)
    unit_spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }.get(unit)

    if unit_spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = unit_spec
    return relativedelta(**{keyword: multiplier * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate ``d`` to the start of the ``unit`` period containing it (DATE_TRUNC).

    Args:
        d: the date or datetime to truncate. Every supported unit is a whole number of
            days, so a ``datetime`` also has its time of day reset to midnight.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies ``WEEK_OFFSET``, the first day of the week in
            ``datetime.date.weekday()`` numbering (Monday == 0). It is applied modulo 7,
            so e.g. -1 and 6 both denote Sunday.

    Returns:
        a value of the same type as ``d``.

    Raises:
        UnsupportedUnit: if ``unit`` is not one of the units listed above.
    """
    if isinstance(d, datetime.datetime):
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days elapsed since the most recent week start, always in 0..6. Without the
        # modulo a week start later in the week than `d` gives a negative delta (a
        # "floor" after `d`), and an offset of -1 on the week-start day itself goes
        # back a full week.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Render ``expression`` as a cheap pseudo-SQL string used as a sort/dedup key.

    Sorting and deduplicating operands needs a string per node, and calling the real
    generator for that is expensive, so this is a bare-minimum, deterministic renderer:

    - ``None`` renders as ``"_"``;
    - iterables are rendered element by element and joined with commas;
    - non-expression values use ``str()``;
    - node types in GEN_MAP use their renderer; any other node renders as
      ``"<key> <rendered args>"``.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    renderer = GEN_MAP.get(type(expression))
    if renderer is None:
        return f"{expression.key} {gen(expression.args.values())}"
    return renderer(expression)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.