sqlglot.optimizer.simplify
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised when a date/interval unit is not handled by the simplifier."""


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # the rules below inspect `parent`, so re-attach the (possibly new) node first
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised"""

    def decorator(func):
        # functools.wraps keeps the decorated rule's name/docstring for debugging and docs
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        negate = isinstance(expression.parent, exp.Not)
        expression = exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
        if negate:
            # NOT binds tighter than AND: without parentheses `NOT x BETWEEN a AND b`
            # would be emitted as `NOT x >= a AND x <= b`, i.e. `(NOT x >= a) AND x <= b`
            expression = exp.Paren(this=expression)
    return expression


def simplify_not(expression):
    """
    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT over NULL / boolean constants and removes double negation.
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """Fold constant operands (TRUE/FALSE/NULL) and redundant comparisons in AND/OR chains."""

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """
    Combine two comparisons of the same column against literals, e.g. x > 1 AND x > 2 -> x > 2.

    Returns the replacement expression, `expression` itself when the operands cannot be
    split into column/value, or None when no simplification applies.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_compliments(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return compliment
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        # keyed by generated SQL, so textually identical operands collapse into one entry
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    """Return whether `expression` is a numeric literal."""
    return expression.is_number


def _is_date(expression: exp.Expression) -> bool:
    """Return whether `expression` is a CAST whose value can be extracted as a date."""
    return isinstance(expression, exp.Cast) and extract_date(expression) is not None


def _is_interval(expression: exp.Expression) -> bool:
    """Return whether `expression` is an INTERVAL that can be converted to a delta."""
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            l, r = r, l
        else:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date(r):
            a_predicate = _is_date
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            a = l.this
            b = exp.Interval(
                this=l.expression.copy(),
                unit=l.unit.copy(),
            )
        else:
            a, b = l.left, l.right

        # exactly one operand must be a constant; it is moved to the other side
        if not a_predicate(a) and b_predicate(b):
            pass
        elif not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Fold binary operations on literals and normalize negated numeric literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """
    Evaluate the binary `expression` over operands `a` and `b` when both are constants.

    Returns the folded expression, or None when nothing can be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # never fold a division by zero: Decimal would raise here, and whether the
            # query errors or yields NULL is the engine's decision
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Remove parentheses that do not affect how the expression is evaluated."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    COALESCE(x) -> x, and a comparison between COALESCE(..., <constant>, ...) and a constant
    is expanded into an explicit NULL check so the constant branch can be folded later.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Keep the operands on their original sides: `1 < COALESCE(x, 2)` must fall back to
    # `1 < 2`, not `2 < 1`, which matters for the non-symmetric comparisons.
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    new_args = []
    for is_string_group, group in itertools.groupby(
        expression.expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string("".join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as a range check, or None if it cannot match."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) <> date` as "outside the range", or None if not applicable."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # The complement of `min <= left < max` is `left < min OR left >= max`; joining the two
    # halves with AND can never be satisfied. Parenthesized so that the OR keeps its meaning
    # when the predicate sits inside an AND chain (redundant parens are removed later).
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        )
    )


DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # trunc(l) < d  <=>  l < ceil(d): when d is not a unit boundary, every value in
    # the period containing d still truncates to something smaller than d
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return whether `left` is a date truncation compared against a temporal CAST `right`."""
    return (
        isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
        and isinstance(right, exp.Cast)
        and right.is_type(*exp.DataType.TEMPORAL_TYPES)
    )


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            # operands are swapped, so the comparison direction has to be mirrored too
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop `WHERE <always true>` clauses and turn eligible `JOIN ... ON <always true>` into CROSS joins."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def _is_zero(expression):
    """Return whether `expression` is a numeric literal equal to zero."""
    if not (isinstance(expression, exp.Literal) and expression.is_number):
        return False
    try:
        return float(expression.name) == 0
    except ValueError:
        # not parseable as a plain number; don't guess
        return False


def always_true(expression):
    """Return a truthy value when `expression` is TRUE or a literal other than numeric zero."""
    # a numeric 0 is falsy in SQL, so e.g. `WHERE 0` must not be treated as always true
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not _is_zero(expression)
    )


def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is the boolean literal FALSE."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is the NULL literal."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` over Python values, or return None if it isn't one."""
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a date, datetime or ISO formatted string to a date; None if it cannot be parsed."""
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a date, datetime or ISO formatted string to a datetime; None if it cannot be parsed."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python equivalent of the temporal type `to`, or None if unsupported."""
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(cast: exp.Cast) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a (possibly nested) CAST of a literal to a date/datetime, or None if not possible."""
    value: t.Any
    if isinstance(cast.this, exp.Literal):
        value = cast.this.name
    elif isinstance(cast.this, exp.Cast):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, cast.to)


def extract_interval(expression):
    """Convert an INTERVAL expression to a relativedelta, or None if it cannot be converted."""
    try:
        # int() raises ValueError for non-integer intervals such as '1.5' or '1 day'
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Build a CAST('<date>' AS DATE) or, for datetimes, CAST('<datetime>' AS DATETIME)."""
    return exp.cast(
        exp.Literal.string(date),
        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
    )


def interval(unit: str, n: int = 1):
    """
    Return a relativedelta of `n` `unit`s.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled units
        ModuleNotFoundError: if python-dateutil is not installed
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Truncate `d` down to the start of its `unit` (year, quarter, month, week or day).

    Raises:
        UnsupportedUnit: for any other unit
    """
    if isinstance(d, datetime.datetime):
        # all supported units are at least a day long, so the floor is always at midnight;
        # otherwise a timestamp like 2021-01-01 12:00 would wrongly count as its own floor
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary; `d` is returned as is when already on one."""
    floor = date_floor(d, unit)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return the TRUE or FALSE expression matching the truthiness of `condition`."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """
    Apply the pairwise `simplifier` across all operands of a flattened binary chain.

    Each operand is tried against the remaining ones; when a pair simplifies, the result
    goes back to the front of the queue so it can be combined further. The chain is only
    rebuilt when at least one pair was merged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(expression):
    """
    Simplify a sqlglot AST by applying rewrite rules until the tree stops changing.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    def _mentions_any(tree, keys):
        # True as soon as any node of `tree` is one of the GROUP BY keys
        return any(node in keys for node, *_ in tree.walk())

    # A projection (or HAVING clause) has to match its GROUP BY key exactly, e.g.
    #   select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # so everything that mentions a key is frozen before any rewriting happens.
    for group in expression.find_all(exp.Group):
        select = group.parent
        keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            if _mentions_any(projection, keys):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and _mentions_any(having, keys):
            having.meta[FINAL] = True

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Rules applied before descending into the children
        node = rewrite_between(expression)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        exp.replace_children(node, lambda child: _simplify(child, False))

        # Rules applied once the children are simplified
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    simplified = while_changing(expression, _simplify)
    remove_where_true(simplified)
    return simplified
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes to swallow.
    Returns:
        A decorator. The decorated function returns its first argument unchanged when
        one of `exceptions` is raised; any other exception propagates.
    """

    def decorator(func):
        # preserve the wrapped function's __name__/__doc__ for debugging and generated docs
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.

    When the BETWEEN is the operand of a NOT, the resulting conjunction is wrapped in
    parentheses so that the negation keeps applying to both comparisons.
    """
    if isinstance(expression, exp.Between):
        negate = isinstance(expression.parent, exp.Not)
        expression = exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
        if negate:
            # NOT binds tighter than AND: without parentheses `NOT x BETWEEN a AND b`
            # would be emitted as `NOT x >= a AND x <= b`, i.e. `(NOT x >= a) AND x <= b`
            expression = exp.Paren(this=expression)
    return expression
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a negation.

    De Morgan's laws are applied to parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    NOT over NULL stays NULL, NOT over a constant is folded, and NOT NOT x becomes x.
    Anything that is not a NOT is returned untouched.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, exp.And):
            negated_left = exp.not_(inner.left, copy=False)
            negated_right = exp.not_(inner.right, copy=False)
            return exp.or_(negated_left, negated_right, copy=False)

        if isinstance(inner, exp.Or):
            negated_left = exp.not_(inner.left, copy=False)
            negated_right = exp.not_(inner.right, copy=False)
            return exp.and_(negated_left, negated_right, copy=False)

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    # NOT NOT x -> x
    return operand.this if isinstance(operand, exp.Not) else expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Drop parentheses around operands that use the same connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    The connector is modified in place and returned; other nodes are returned as is.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Reduce AND / OR connectors whose operands are boolean literals, NULLs or duplicates.

    Pairs of operands are handed to the reducers below through `_flat_simplify`;
    pairs that cannot be decided from literals alone are passed on to
    `_simplify_comparison`. Non-connector nodes are returned unchanged.
    """

    def _simplify_and(expression, left, right):
        if is_false(left) or is_false(right):
            return exp.false()
        if is_null(left) or is_null(right):
            return exp.null()
        if always_true(left):
            # TRUE AND TRUE -> TRUE, TRUE AND x -> x
            return exp.true() if always_true(right) else right
        if always_true(right):
            return left
        return _simplify_comparison(expression, left, right)

    def _simplify_or(expression, left, right):
        if always_true(left) or always_true(right):
            return exp.true()
        if is_false(left) and is_false(right):
            return exp.false()
        if (is_null(left) and (is_null(right) or is_false(right))) or (
            is_false(left) and is_null(right)
        ):
            return exp.null()
        if is_false(left):
            return right
        if is_false(right):
            return left
        return _simplify_comparison(expression, left, right, or_=True)

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            return _simplify_and(expression, left, right)
        if isinstance(expression, exp.Or):
            return _simplify_or(expression, left, right)
        return None

    if not isinstance(expression, exp.Connector):
        return expression

    return _flat_simplify(expression, _simplify_connectors, root)
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its complement.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only the root node, or a connector for which `same_parent` is falsy, is examined.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    if not root and expression.same_parent:
        return expression

    # Ordered pairs are needed because is_complement is called with both argument orders
    ordered_pairs = itertools.permutations(expression.flatten(), 2)

    if any(is_complement(a, b) for a, b in ordered_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Remove complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate the operands of a connector and put them in sorted order.

    C AND A AND B AND B -> A AND B AND C

    Operands are compared and ordered by the text that `generate` produces for
    them. Only the root node, or a connector for which `same_parent` is falsy,
    is processed; a connector that is already sorted and free of duplicates is
    returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    if not root and expression.same_parent:
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Operands that generate the same text collapse into one entry (the last one wins)
    by_sql = {generate(operand): operand for operand in operands}
    keyed = tuple(by_sql.items())

    # A AND C AND B -> A AND B AND C
    out_of_order = any(
        following[0] < preceding[0] for preceding, following in zip(keyed, keyed[1:])
    )

    if out_of_order:
        return combine(*(operand for _, operand in sorted(keyed)), copy=False)

    # Already sorted, but duplicates may still have to go
    if len(by_sql) < len(operands):
        return combine(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, editing the tree in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are not removed here: they are replaced by TRUE / FALSE literals
    or by the surviving operand, so the result can still contain redundant
    terms (presumably cleaned up by the other simplification rules).

    Only the root node, or a connector for which `same_parent` is falsy, is
    processed. The (mutated) input expression is always returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The laws relate a connector to nested connectors of the opposite kind
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair of operands is tried, so each operand gets to play both roles
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # The complemented operand becomes the literal that is neutral inside `a`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's, so `a` is redundant: it
                    # becomes the literal that is neutral for the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # Both sides share one operand and disagree only on the sign of the other
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
def wrapped(expression, *args, **kwargs):
    """Call `func`; if it raises one of `exceptions`, return `expression` unchanged."""
    try:
        simplified = func(expression, *args, **kwargs)
    except exceptions:
        # The simplification is skipped rather than failing
        return expression
    return simplified
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:
  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
    """
    Simplify literal arithmetic.

    Binary operations other than connectors are delegated to `_flat_simplify`
    with `_simplify_binary`. A negated numeric literal is folded into a single
    literal with the sign flipped. Everything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary):
        if not isinstance(expression, exp.Connector):
            return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Flip the sign: drop an existing leading "-" or prepend one
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    return expression
def simplify_parens(expression):
    """
    Remove parentheses that are not needed to preserve the meaning of the expression.

    Returns the wrapped node when the parentheses can be dropped, otherwise the
    Paren itself. Parenthesized SELECTs are never unwrapped, and nodes that
    are not a Paren are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )

    return inner if removable else expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    A COALESCE with a single argument is replaced by that argument. A
    comparison between a COALESCE and a constant is split into two cases:

        (<head> IS NOT NULL AND <comparison with head>)
        OR (<head> IS NULL AND <comparison with the first constant fallback>)

    where <head> is the COALESCE truncated before its first constant fallback.
    The original operand order is kept in both cases, so that non-symmetric
    comparisons (<, <=, >, >=) keep their meaning when the COALESCE is the
    right-hand operand.

    Note that the COALESCE inside `expression` is truncated in place.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Everything from the first constant onwards can never be reached once the
    # arguments before it are handled by the IS NULL / IS NOT NULL split below.
    # This must happen before `expression.copy()` so the copy sees the truncated call.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # The constant fallback takes the COALESCE's place, on the same side of the comparison
    if coalesce_on_left:
        fallback_operands = {"this": arg.copy(), "expression": other.copy()}
    else:
        fallback_operands = {"this": other.copy(), "expression": arg.copy()}

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(**fallback_operands),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge every run of adjacent string literals in a concatenation into one literal.

    CONCAT_WS and nodes that are not one of `CONCATS` are returned unchanged.
    If a single operand remains it is returned on its own; otherwise a new
    concat node is built, keeping the "safe" flavour of the original.
    """
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for all_strings, run in itertools.groupby(operands, lambda operand: operand.is_string):
        if not all_strings:
            merged.extend(run)
            continue
        merged.append(exp.Literal.string("".join(literal.name for literal in run)))

    if len(merged) == 1:
        return merged[0]

    # Preserve the concat type, i.e. whether it's "safe" or not
    if isinstance(expression, SAFE_CONCATS):
        return exp.SafeConcat(expressions=merged)
    return exp.Concat(expressions=merged)
Reduces all groups that contain string literals by concatenating them.
def wrapped(expression, *args, **kwargs):
    """Call `func`; if it raises one of `exceptions`, return `expression` unchanged."""
    try:
        simplified = func(expression, *args, **kwargs)
    except exceptions:
        # The simplification is skipped rather than failing
        return expression
    return simplified
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def remove_where_true(expression):
    """
    Drop filters that are always true, editing `expression` in place.

    A WHERE clause whose condition is always true is removed from its parent.
    A join whose ON condition is always true, that has no USING and no method,
    and whose (side, kind) pair is listed in `JOINS`, is turned into a CROSS join.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        for key, value in (("on", None), ("side", None), ("kind", "CROSS")):
            join.set(key, value)
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns the outcome wrapped by `boolean_literal`, or None when
    `expression` is not one of the supported comparison nodes.
    """
    # The comparisons are wrapped in lambdas so that only the matching one is evaluated
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a datetime, returning None if that is not possible.

    Args:
        value: a datetime (returned as is), a date (promoted to midnight of
            that day) or an ISO 8601 string.

    Returns:
        The corresponding datetime, or None if `value` is neither a date nor
        a parseable ISO string.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: not an ISO formatted string
        # TypeError: not a string at all (e.g. None or a number)
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python date type that matches the SQL type `to`.

    Args:
        value: the value to convert; falsy values are never converted.
        to: the target data type.

    Returns:
        The result of `cast_as_date` for DATE, the result of `cast_as_datetime`
        for the other temporal types, and None for a falsy `value` or any
        other target type.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Cast) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a cast of a literal (or of nested casts of a literal) to a date value.

    Args:
        cast: the CAST node to evaluate.

    Returns:
        Whatever `cast_value` yields for the innermost literal after applying
        each cast's target type from the inside out, or None when the operand
        is neither a literal nor another cast.
    """
    value: t.Any
    if isinstance(cast.this, exp.Literal):
        value = cast.this.name
    elif isinstance(cast.this, exp.Cast):
        # Resolve the inner cast first, then apply this cast's target type to its result
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, cast.to)
def interval(unit: str, n: int = 1):
    """
    Build a `relativedelta` spanning `n` times the given unit.

    Supported units are year, quarter (three months), month, week, day, hour,
    minute and second.

    Raises:
        UnsupportedUnit: if `unit` is none of the above.
    """
    # Imported first so that a missing dateutil is reported before the unit is looked at
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, how many of that keyword make up one unit)
    spans = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in spans:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Move `d` back to the first day of the year, quarter, month or week that contains it.

    Only the date part is adjusted. For the "day" unit `d` is returned as is.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week or day.
    """
    if unit == "year":
        return d.replace(month=1, day=1)

    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)

    if unit == "month":
        return d.replace(day=1)

    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        days_since_monday = d.weekday()
        return d - datetime.timedelta(days=days_since_monday)

    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")