sqlglot.optimizer.simplify
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised by the date helpers (`interval`, `date_floor`) for a unit they cannot handle."""


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # later rules (e.g. simplify_parens, _flat_simplify) inspect the parent
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The wrapped function then returns its first argument (the input expression) unchanged.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    De Morgan's laws and constant folding for NOT:

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    NOT NULL -> NULL, NOT FALSE -> TRUE, NOT NOT x -> x
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """
    Fold constant operands out of AND / OR chains, following SQL three-valued logic.

    x AND FALSE -> FALSE, x AND TRUE -> x, NULL AND NULL -> NULL
    x OR TRUE -> TRUE, x OR FALSE -> x, NULL OR FALSE -> NULL
    Other operand pairs are passed to `_simplify_comparison`.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NULL AND x is FALSE when x is FALSE, so NULL only absorbs another NULL
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons of the same column against constants.

    For example, x > 1 AND x > 3 -> x > 3, and x < 1 AND x > 3 -> FALSE. Returns the merged
    node, or a falsy value / `expression` itself when nothing can be merged.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    if av == bv:
                        # same bound: AND keeps the strict operator, OR the inclusive one
                        # (x <= 5 AND x < 5 -> x < 5, x <= 5 OR x < 5 -> x <= 5)
                        return left if isinstance(left, exp.LTE if or_ else exp.LT) else right
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    if av == bv:
                        return left if isinstance(left, exp.GTE if or_ else exp.GT) else right
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_compliments(expression, root=True):
    """
    Remove complements (the function name keeps its historical spelling).

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # a set makes each NOT A -> A lookup O(1) instead of testing every operand pair
        operands = set(expression.flatten())
        for operand in operands:
            if isinstance(operand, exp.Not) and operand.this in operands:
                return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    A mirrored comparison (3 < x + 1) is handled by flipping the operator (x + 1 > 3).
    """
    if isinstance(expression, COMPARISONS):
        comparison = expression.__class__
        l, r = expression.left, expression.right

        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            # moving the arithmetic to the left requires mirroring the operator: < becomes >
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            unit = l.unit
            if not unit:
                # without a unit the added amount cannot be turned into an interval
                return expression
            a = l.this
            b = exp.Interval(
                this=l.expression.copy(),
                unit=unit.copy(),
            )
        else:
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
            # only addition is commutative: 1 - x = 3 must not become x = 3 + 1
            a, b = b, a
        else:
            return expression

        return comparison(this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b))
    return expression


def simplify_literals(expression, root=True):
    """Fold binary operations over constants and negations of numeric literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """Evaluate `a <op> b` for constant operands; return None when it cannot be folded."""
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # division by zero is engine-specific (error, NULL, inf) and would raise here
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Drop parentheses whose removal cannot change the meaning of the expression."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """Remove single-argument COALESCE and expand COALESCE(...) <cmp> constant."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    new_args = []
    for is_string_group, group in itertools.groupby(
        expression.expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string("".join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, x) <> date  ->  (x < range_start OR x >= range_end)."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # outside [start, end) means below start OR at/after end; an AND would never hold.
    # Parenthesized so the OR keeps its meaning when nested inside an AND.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        )
    )


DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # DATE_TRUNC(u, x) < d holds for every x whose truncation is below d; when d is not
    # aligned to `u` that includes the whole bucket containing d, i.e. x < ceil(d)
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)
            conditions = [_datetrunc_eq_expression(l, drange) for drange in ranges]

            if len(conditions) == 1:
                return conditions[0]

            # parenthesize so the OR keeps its meaning when nested inside an AND
            return exp.paren(exp.or_(*conditions, copy=False))

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible `JOIN ... ON <true>` into CROSS JOINs."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def _is_zero(expression: exp.Expression) -> bool:
    """Return True if `expression` is a numeric literal equal to zero, e.g. 0 or 0.0."""
    if not (isinstance(expression, exp.Literal) and expression.is_number):
        return False
    try:
        return Decimal(expression.name) == 0
    except ArithmeticError:
        # unparseable numeric text (decimal.InvalidOperation); not known to be zero
        return False


def always_true(expression):
    """Truthy when `expression` is TRUE or a literal other than numeric zero."""
    # numeric 0 is falsy in engines that accept numbers as predicates (WHERE 0, ON 0)
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not _is_zero(expression)
    )


def is_complement(a, b):
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate a comparison between two Python values; None for non-comparison nodes."""
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Return the Python date/datetime for a (possibly nested) CAST / TS_OR_DS_TO_DATE literal."""
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    return extract_date(expression) is not None


def extract_interval(expression):
    """Return a relativedelta for an INTERVAL node, or None if it cannot be interpreted."""
    try:
        # non-integer amounts (e.g. '1.5') make int() raise ValueError
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    return exp.cast(
        exp.Literal.string(date),
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE,
    )


def interval(unit: str, n: int = 1):
    """Return a `dateutil` relativedelta of `n` units; raise UnsupportedUnit otherwise."""
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of its `unit` bucket."""
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary (unchanged if already aligned)."""
    floor = date_floor(d, unit)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """
    Apply a pairwise `simplifier` across a flattened chain of the same binary operator.

    Whenever `simplifier(expression, a, b)` yields a new node, it replaces `a` and `b` and is
    retried against the remaining operands. The chain is rebuilt only if it shrank.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
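Raised when a date or interval unit is not supported by the simplifier's date helpers (`interval`, `date_floor`); inherits from the builtin `Exception`.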
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(expression):
    """
    Simplify a sqlglot AST, replacing the root in place, and return the simplified root.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """
    generate = cached_generator()

    def _freeze_if_references(node, keys):
        # A projection or HAVING clause that contains a GROUP BY key must stay verbatim,
        # e.g. SELECT x + 1 + 1 FROM y GROUP BY x + 1 + 1
        if any(sub in keys for sub, *_ in node.walk()):
            node.meta[FINAL] = True

    for group in expression.find_all(exp.Group):
        select = group.parent
        keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            _freeze_if_references(projection, keys)

        having = select.args.get("having")
        if having:
            _freeze_if_references(having, keys)

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order: rewrite this node before its children are visited
        node = simplify_concat(
            absorb_and_eliminate(
                uniq_sort(rewrite_between(expression), generate, root),
                root,
            )
        )

        exp.replace_children(node, lambda child: _simplify(child, False))

        # Post-order: rewrite this node now that its children are simplified
        for rule in (simplify_not, flatten):
            node = rule(node)
        node = remove_compliments(simplify_connectors(node, root), root)
        node = simplify_coalesce(node)
        # later rules (e.g. simplify_parens) read node.parent
        node.parent = expression.parent
        node = simplify_literals(node, root)
        for rule in (simplify_equality, simplify_parens, simplify_datetrunc_predicate):
            node = rule(node)

        if root:
            expression.replace(node)
        return node

    simplified = while_changing(expression, _simplify)
    remove_where_true(simplified)
    return simplified
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    When one of `exceptions` is raised, the wrapped function returns its first argument
    (the input expression) unchanged, so the failed rewrite is simply skipped. Any other
    exception propagates. The wrapper keeps the wrapped function's name and docstring.
    """
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised: the function's input expression is returned unchanged.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is expanded
    first. Any other expression is returned untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower, upper, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Push NOT through AND/OR (De Morgan's laws) and fold constant negations.

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    NOT NULL -> NULL, NOT FALSE -> TRUE, NOT <always-true constant> -> FALSE,
    NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # AND flips to OR and vice versa, negating both sides
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Merge operands that are (possibly parenthesised) connectors of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR connectors whose operands are identical, boolean constants or NULL.

    AND: x AND x -> x; FALSE wins over everything, then NULL; TRUE operands drop out.
    OR:  x OR x -> x; TRUE wins; FALSE operands drop out; NULL with NULL/FALSE -> NULL.
    Operand pairs matching none of these are handed to ``_simplify_comparison``.
    """

    def _fold_pair(connector, lhs, rhs):
        if lhs == rhs:
            return lhs

        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()
            if always_true(lhs):
                # TRUE AND TRUE -> TRUE, TRUE AND x -> x
                return exp.true() if always_true(rhs) else rhs
            if always_true(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()
            if is_false(lhs):
                if is_false(rhs):
                    return exp.false()
                # FALSE OR NULL -> NULL, FALSE OR x -> x
                return exp.null() if is_null(rhs) else rhs
            if is_null(lhs) and (is_null(rhs) or is_false(rhs)):
                return exp.null()
            if is_false(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _fold_pair, root)
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its complement.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = expression.flatten()
    if any(is_complement(a, b) for a, b in itertools.permutations(operands, 2)):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Remove complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate a connector's operands and order them by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    The connector is only rebuilt when its operands are out of order or
    contain duplicates; otherwise the original node is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Duplicate SQL keeps the position of its first occurrence but the node of its last.
    by_sql = {generate(operand): operand for operand in operands}
    pairs = tuple(by_sql.items())

    out_of_order = any(nxt[0] < prev[0] for prev, nxt in zip(pairs, pairs[1:]))
    if out_of_order:
        return build(*(node for _, node in sorted(pairs)), copy=False)
    if len(by_sql) < len(operands):
        # Already sorted, but duplicates were dropped.
        return build(*by_sql.values(), copy=False)
    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply boolean absorption and elimination to a connector's operands, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten via ``replace`` with TRUE/FALSE or a shared term;
    the (mutated) input ``expression`` is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the dual connector: operands of this type are the ones that
        # can be absorbed/eliminated inside the outer connector.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # NOTE(review): permutations() snapshots the flattened operands up front,
        # so nodes replaced below may still be visited later as `a`/`b`.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # b's complement inside `a` is replaced by the identity element of
                # `kind` (FALSE for OR, TRUE for AND), removing it from `a`.
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands are a strict subset of a's operands: `a` is redundant,
                # so it becomes the identity element of the outer connector.
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one term and differ by a complemented term:
                    # both collapse to the shared term, leaving a duplicate that
                    # simplify_connectors' `left == right` check can fold.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
99 def wrapped(expression, *args, **kwargs): 100 try: 101 return func(expression, *args, **kwargs) 102 except exceptions: 103 return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
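There are two binary operations in the above expression: + and =.
Here's how we reference all the operands in the code below:

      l     r
    x + 1 = 3
    a   b

Here l is the left operand of = (the node x + 1) and r is its right operand (3); a and b are the operands of + (x and 1).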
def simplify_literals(expression, root=True):
    """Fold literal arithmetic and negate numeric literals.

    Non-connector binary nodes are delegated to ``_flat_simplify`` with
    ``_simplify_binary``; ``-<number>`` flips the sign in the literal's text.
    Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    text = operand.name
    # A leading "-" is stripped; otherwise one is prepended.
    return exp.Literal.number(text[1:] if text[0] == "-" else f"-{text}")
def simplify_parens(expression):
    """Drop parentheses that cannot change how the wrapped expression binds.

    A parenthesised SELECT is always kept. Otherwise the parens are removed when
    the parent is not a condition/binary node, the parent is itself a Paren, the
    inner node is not binary, a predicate sits under a non-predicate, or the
    inner/outer operators are an associative or higher-precedence pairing
    (Add in Add, Mul in Mul/Add/Sub).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression):
    """
    Remove trivial COALESCE calls and expand comparisons against COALESCE.

    COALESCE(x) -> x, unless the node is a Spark partitioning hint.

    For ``COALESCE(x, ..., <const>, ...) <op> <const>`` the COALESCE arguments
    are truncated before the first constant argument (``prefix``) and the
    comparison becomes

        (prefix IS NOT NULL AND prefix <op> <const>)
        OR (prefix IS NULL AND <first const arg> <op> <const>)

    which is larger but folds further in later passes.
    """
    # COALESCE(x) -> x; COALESCE doubles as a Spark partitioning hint, so hints are kept.
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # Valid for any operand, but only worthwhile when the other side is a constant.
    if not isinstance(other, CONSTANTS):
        return expression

    first_constant = next(
        ((index, arg) for index, arg in enumerate(coalesce.expressions) if isinstance(arg, CONSTANTS)),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # With no remaining extra args, use the bare first argument directly; this
    # saves a pass that would otherwise unwrap COALESCE(x) at the top of this function.
    if not coalesce.expressions:
        coalesce = coalesce.this

    not_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False))
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a CONCAT into one literal.

    CONCAT_WS and non-concat nodes are returned unchanged. If only one operand
    remains it is returned on its own; otherwise a new SafeConcat/Concat
    (matching the input's kind) is built from the merged operands.
    """
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    operands = expression.expressions or expression.flatten()
    merged = []
    for all_strings, run in itertools.groupby(operands, key=lambda node: node.is_string):
        if all_strings:
            merged.append(exp.Literal.string("".join(node.name for node in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]

    # Preserve whether the original concat was the "safe" variant.
    builder = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return builder(expressions=merged)
Reduces all groups that contain string literals by concatenating them.
99 def wrapped(expression, *args, **kwargs): 100 try: 101 return func(expression, *args, **kwargs) 102 except exceptions: 103 return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def remove_where_true(expression):
    """Drop always-true WHERE clauses and convert always-true joins to CROSS joins, in place.

    A join is converted only when it has no USING and no join method and its
    ``(side, kind)`` pair appears in the module-level ``JOINS`` collection.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison node ``expression`` on the Python values ``a`` and ``b``.

    EQ and IS compare with ``==``; NEQ, GT, GTE, LT and LTE use the matching
    Python operator. The result is wrapped with ``boolean_literal``. Any other
    node type yields None.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce ``value`` to a ``datetime.datetime`` when possible.

    Args:
        value: a datetime (returned as-is), a date (promoted to midnight), or an
            ISO-8601 string.

    Returns:
        The resulting datetime, or None if ``value`` cannot be interpreted as one.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # fromisoformat raises TypeError for non-str input and ValueError for
        # malformed strings; both mean "not a datetime".
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert ``value`` to a Python date/datetime according to the target type ``to``.

    Args:
        value: the value being cast (e.g. a literal's text or an already
            extracted date).
        to: the SQL data type being cast to.

    Returns:
        None when ``value`` is falsy or ``to`` is not a temporal type. Otherwise,
        the result of ``cast_as_date`` for DATE targets, or of
        ``cast_as_datetime`` for the remaining temporal types.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a constant CAST / TsOrDsToDate node to a Python date or datetime.

    The target type is the CAST's ``to`` type, or DATE for TsOrDsToDate. The
    operand must be a literal or another (nested) CAST / TsOrDsToDate node;
    anything else yields None.

    Returns:
        The value produced by ``cast_value`` for the extracted operand, or None.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are evaluated recursively, innermost first.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Return a ``relativedelta`` spanning ``n`` of the given ``unit``.

    Supported units are year, quarter (three months), month, week, day, hour,
    minute and second. Any other unit raises ``UnsupportedUnit``.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, multiplier)
    spans = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, factor in spans:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round ``d`` down to the start of the enclosing ``unit``.

    Supported units are year, quarter, month, week (weeks start on Monday) and
    day (``d`` is returned unchanged). Any other unit raises ``UnsupportedUnit``.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday, so this steps back to that week's Monday.
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")