# sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import functools 5import itertools 6import typing as t 7from collections import deque 8from decimal import Decimal 9 10import sqlglot 11from sqlglot import Dialect, exp 12from sqlglot.helper import first, merge_ranges, while_changing 13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 14 15if t.TYPE_CHECKING: 16 from sqlglot.dialects.dialect import DialectType 17 18 DateTruncBinaryTransform = t.Callable[ 19 [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] 20 ] 21 22# Final means that an expression should not be simplified 23FINAL = "final" 24 25 26class UnsupportedUnit(Exception): 27 pass 28 29 30def simplify( 31 expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None 32): 33 """ 34 Rewrite sqlglot AST to simplify expressions. 35 36 Example: 37 >>> import sqlglot 38 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 39 >>> simplify(expression).sql() 40 'TRUE' 41 42 Args: 43 expression (sqlglot.Expression): expression to simplify 44 constant_propagation: whether the constant propagation rule should be used 45 46 Returns: 47 sqlglot.Expression: simplified expression 48 """ 49 50 dialect = Dialect.get_or_raise(dialect) 51 52 def _simplify(expression, root=True): 53 if expression.meta.get(FINAL): 54 return expression 55 56 # group by expressions cannot be simplified, for example 57 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 58 # the projection must exactly match the group by key 59 group = expression.args.get("group") 60 61 if group and hasattr(expression, "selects"): 62 groups = set(group.expressions) 63 group.meta[FINAL] = True 64 65 for e in expression.selects: 66 for node in e.walk(): 67 if node in groups: 68 e.meta[FINAL] = True 69 break 70 71 having = expression.args.get("having") 72 if having: 73 for node in having.walk(): 74 if node in groups: 75 having.meta[FINAL] = True 76 break 77 78 # Pre-order transformations 79 node = 
expression 80 node = rewrite_between(node) 81 node = uniq_sort(node, root) 82 node = absorb_and_eliminate(node, root) 83 node = simplify_concat(node) 84 node = simplify_conditionals(node) 85 86 if constant_propagation: 87 node = propagate_constants(node, root) 88 89 exp.replace_children(node, lambda e: _simplify(e, False)) 90 91 # Post-order transformations 92 node = simplify_not(node) 93 node = flatten(node) 94 node = simplify_connectors(node, root) 95 node = remove_complements(node, root) 96 node = simplify_coalesce(node) 97 node.parent = expression.parent 98 node = simplify_literals(node, root) 99 node = simplify_equality(node) 100 node = simplify_parens(node) 101 node = simplify_datetrunc(node, dialect) 102 node = sort_comparison(node) 103 node = simplify_startswith(node) 104 105 if root: 106 expression.replace(node) 107 return node 108 109 expression = while_changing(expression, _simplify) 110 remove_where_true(expression) 111 return expression 112 113 114def catch(*exceptions): 115 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 116 117 def decorator(func): 118 def wrapped(expression, *args, **kwargs): 119 try: 120 return func(expression, *args, **kwargs) 121 except exceptions: 122 return expression 123 124 return wrapped 125 126 return decorator 127 128 129def rewrite_between(expression: exp.Expression) -> exp.Expression: 130 """Rewrite x between y and z to x >= y AND x <= z. 131 132 This is done because comparison simplification is only done on lt/lte/gt/gte. 
133 """ 134 if isinstance(expression, exp.Between): 135 negate = isinstance(expression.parent, exp.Not) 136 137 expression = exp.and_( 138 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 139 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 140 copy=False, 141 ) 142 143 if negate: 144 expression = exp.paren(expression, copy=False) 145 146 return expression 147 148 149COMPLEMENT_COMPARISONS = { 150 exp.LT: exp.GTE, 151 exp.GT: exp.LTE, 152 exp.LTE: exp.GT, 153 exp.GTE: exp.LT, 154 exp.EQ: exp.NEQ, 155 exp.NEQ: exp.EQ, 156} 157 158 159def simplify_not(expression): 160 """ 161 Demorgan's Law 162 NOT (x OR y) -> NOT x AND NOT y 163 NOT (x AND y) -> NOT x OR NOT y 164 """ 165 if isinstance(expression, exp.Not): 166 this = expression.this 167 if is_null(this): 168 return exp.null() 169 if this.__class__ in COMPLEMENT_COMPARISONS: 170 return COMPLEMENT_COMPARISONS[this.__class__]( 171 this=this.this, expression=this.expression 172 ) 173 if isinstance(this, exp.Paren): 174 condition = this.unnest() 175 if isinstance(condition, exp.And): 176 return exp.paren( 177 exp.or_( 178 exp.not_(condition.left, copy=False), 179 exp.not_(condition.right, copy=False), 180 copy=False, 181 ) 182 ) 183 if isinstance(condition, exp.Or): 184 return exp.paren( 185 exp.and_( 186 exp.not_(condition.left, copy=False), 187 exp.not_(condition.right, copy=False), 188 copy=False, 189 ) 190 ) 191 if is_null(condition): 192 return exp.null() 193 if always_true(this): 194 return exp.false() 195 if is_false(this): 196 return exp.true() 197 if isinstance(this, exp.Not): 198 # double negation 199 # NOT NOT x -> x 200 return this.this 201 return expression 202 203 204def flatten(expression): 205 """ 206 A AND (B AND C) -> A AND B AND C 207 A OR (B OR C) -> A OR B OR C 208 """ 209 if isinstance(expression, exp.Connector): 210 for node in expression.args.values(): 211 child = node.unnest() 212 if isinstance(child, expression.__class__): 213 node.replace(child) 
214 return expression 215 216 217def simplify_connectors(expression, root=True): 218 def _simplify_connectors(expression, left, right): 219 if left == right: 220 return left 221 if isinstance(expression, exp.And): 222 if is_false(left) or is_false(right): 223 return exp.false() 224 if is_null(left) or is_null(right): 225 return exp.null() 226 if always_true(left) and always_true(right): 227 return exp.true() 228 if always_true(left): 229 return right 230 if always_true(right): 231 return left 232 return _simplify_comparison(expression, left, right) 233 elif isinstance(expression, exp.Or): 234 if always_true(left) or always_true(right): 235 return exp.true() 236 if is_false(left) and is_false(right): 237 return exp.false() 238 if ( 239 (is_null(left) and is_null(right)) 240 or (is_null(left) and is_false(right)) 241 or (is_false(left) and is_null(right)) 242 ): 243 return exp.null() 244 if is_false(left): 245 return right 246 if is_false(right): 247 return left 248 return _simplify_comparison(expression, left, right, or_=True) 249 250 if isinstance(expression, exp.Connector): 251 return _flat_simplify(expression, _simplify_connectors, root) 252 return expression 253 254 255LT_LTE = (exp.LT, exp.LTE) 256GT_GTE = (exp.GT, exp.GTE) 257 258COMPARISONS = ( 259 *LT_LTE, 260 *GT_GTE, 261 exp.EQ, 262 exp.NEQ, 263 exp.Is, 264) 265 266INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 267 exp.LT: exp.GT, 268 exp.GT: exp.LT, 269 exp.LTE: exp.GTE, 270 exp.GTE: exp.LTE, 271} 272 273NONDETERMINISTIC = (exp.Rand, exp.Randn) 274 275 276def _simplify_comparison(expression, left, right, or_=False): 277 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 278 ll, lr = left.args.values() 279 rl, rr = right.args.values() 280 281 largs = {ll, lr} 282 rargs = {rl, rr} 283 284 matching = largs & rargs 285 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 286 287 if matching and columns: 288 try: 289 l = 
first(largs - columns) 290 r = first(rargs - columns) 291 except StopIteration: 292 return expression 293 294 if l.is_number and r.is_number: 295 l = float(l.name) 296 r = float(r.name) 297 elif l.is_string and r.is_string: 298 l = l.name 299 r = r.name 300 else: 301 l = extract_date(l) 302 if not l: 303 return None 304 r = extract_date(r) 305 if not r: 306 return None 307 308 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 309 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 310 return left if (av > bv if or_ else av <= bv) else right 311 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 312 return left if (av < bv if or_ else av >= bv) else right 313 314 # we can't ever shortcut to true because the column could be null 315 if not or_: 316 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 317 if av <= bv: 318 return exp.false() 319 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 320 if av >= bv: 321 return exp.false() 322 elif isinstance(a, exp.EQ): 323 if isinstance(b, exp.LT): 324 return exp.false() if av >= bv else a 325 if isinstance(b, exp.LTE): 326 return exp.false() if av > bv else a 327 if isinstance(b, exp.GT): 328 return exp.false() if av <= bv else a 329 if isinstance(b, exp.GTE): 330 return exp.false() if av < bv else a 331 if isinstance(b, exp.NEQ): 332 return exp.false() if av == bv else a 333 return None 334 335 336def remove_complements(expression, root=True): 337 """ 338 Removing complements. 339 340 A AND NOT A -> FALSE 341 A OR NOT A -> TRUE 342 """ 343 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 344 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 345 346 for a, b in itertools.permutations(expression.flatten(), 2): 347 if is_complement(a, b): 348 return complement 349 return expression 350 351 352def uniq_sort(expression, root=True): 353 """ 354 Uniq and sort a connector. 
355 356 C AND A AND B AND B -> A AND B AND C 357 """ 358 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 359 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 360 flattened = tuple(expression.flatten()) 361 deduped = {gen(e): e for e in flattened} 362 arr = tuple(deduped.items()) 363 364 # check if the operands are already sorted, if not sort them 365 # A AND C AND B -> A AND B AND C 366 for i, (sql, e) in enumerate(arr[1:]): 367 if sql < arr[i][0]: 368 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 369 break 370 else: 371 # we didn't have to sort but maybe we need to dedup 372 if len(deduped) < len(flattened): 373 expression = result_func(*deduped.values(), copy=False) 374 375 return expression 376 377 378def absorb_and_eliminate(expression, root=True): 379 """ 380 absorption: 381 A AND (A OR B) -> A 382 A OR (A AND B) -> A 383 A AND (NOT A OR B) -> A AND B 384 A OR (NOT A AND B) -> A OR B 385 elimination: 386 (A AND B) OR (A AND NOT B) -> A 387 (A OR B) AND (A OR NOT B) -> A 388 """ 389 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 390 kind = exp.Or if isinstance(expression, exp.And) else exp.And 391 392 for a, b in itertools.permutations(expression.flatten(), 2): 393 if isinstance(a, kind): 394 aa, ab = a.unnest_operands() 395 396 # absorb 397 if is_complement(b, aa): 398 aa.replace(exp.true() if kind == exp.And else exp.false()) 399 elif is_complement(b, ab): 400 ab.replace(exp.true() if kind == exp.And else exp.false()) 401 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 402 a.replace(exp.false() if kind == exp.And else exp.true()) 403 elif isinstance(b, kind): 404 # eliminate 405 rhs = b.unnest_operands() 406 ba, bb = rhs 407 408 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 409 a.replace(aa) 410 b.replace(aa) 411 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 412 a.replace(ab) 
413 b.replace(ab) 414 415 return expression 416 417 418def propagate_constants(expression, root=True): 419 """ 420 Propagate constants for conjunctions in DNF: 421 422 SELECT * FROM t WHERE a = b AND b = 5 becomes 423 SELECT * FROM t WHERE a = 5 AND b = 5 424 425 Reference: https://www.sqlite.org/optoverview.html 426 """ 427 428 if ( 429 isinstance(expression, exp.And) 430 and (root or not expression.same_parent) 431 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 432 ): 433 constant_mapping = {} 434 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 435 if isinstance(expr, exp.EQ): 436 l, r = expr.left, expr.right 437 438 # TODO: create a helper that can be used to detect nested literal expressions such 439 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 440 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 441 constant_mapping[l] = (id(l), r) 442 443 if constant_mapping: 444 for column in find_all_in_scope(expression, exp.Column): 445 parent = column.parent 446 column_id, constant = constant_mapping.get(column) or (None, None) 447 if ( 448 column_id is not None 449 and id(column) != column_id 450 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 451 ): 452 column.replace(constant.copy()) 453 454 return expression 455 456 457INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 458 exp.DateAdd: exp.Sub, 459 exp.DateSub: exp.Add, 460 exp.DatetimeAdd: exp.Sub, 461 exp.DatetimeSub: exp.Add, 462} 463 464INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 465 **INVERSE_DATE_OPS, 466 exp.Add: exp.Sub, 467 exp.Sub: exp.Add, 468} 469 470 471def _is_number(expression: exp.Expression) -> bool: 472 return expression.is_number 473 474 475def _is_interval(expression: exp.Expression) -> bool: 476 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 477 478 
479@catch(ModuleNotFoundError, UnsupportedUnit) 480def simplify_equality(expression: exp.Expression) -> exp.Expression: 481 """ 482 Use the subtraction and addition properties of equality to simplify expressions: 483 484 x + 1 = 3 becomes x = 2 485 486 There are two binary operations in the above expression: + and = 487 Here's how we reference all the operands in the code below: 488 489 l r 490 x + 1 = 3 491 a b 492 """ 493 if isinstance(expression, COMPARISONS): 494 l, r = expression.left, expression.right 495 496 if l.__class__ not in INVERSE_OPS: 497 return expression 498 499 if r.is_number: 500 a_predicate = _is_number 501 b_predicate = _is_number 502 elif _is_date_literal(r): 503 a_predicate = _is_date_literal 504 b_predicate = _is_interval 505 else: 506 return expression 507 508 if l.__class__ in INVERSE_DATE_OPS: 509 l = t.cast(exp.IntervalOp, l) 510 a = l.this 511 b = l.interval() 512 else: 513 l = t.cast(exp.Binary, l) 514 a, b = l.left, l.right 515 516 if not a_predicate(a) and b_predicate(b): 517 pass 518 elif not a_predicate(b) and b_predicate(a): 519 a, b = b, a 520 else: 521 return expression 522 523 return expression.__class__( 524 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 525 ) 526 return expression 527 528 529def simplify_literals(expression, root=True): 530 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 531 return _flat_simplify(expression, _simplify_binary, root) 532 533 if isinstance(expression, exp.Neg): 534 this = expression.this 535 if this.is_number: 536 value = this.name 537 if value[0] == "-": 538 return exp.Literal.number(value[1:]) 539 return exp.Literal.number(f"-{value}") 540 541 if type(expression) in INVERSE_DATE_OPS: 542 return _simplify_binary(expression, expression.this, expression.interval()) or expression 543 544 return expression 545 546 547NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 548 549 550def _simplify_binary(expression, a, b): 551 if 
isinstance(expression, exp.Is): 552 if isinstance(b, exp.Not): 553 c = b.this 554 not_ = True 555 else: 556 c = b 557 not_ = False 558 559 if is_null(c): 560 if isinstance(a, exp.Literal): 561 return exp.true() if not_ else exp.false() 562 if is_null(a): 563 return exp.false() if not_ else exp.true() 564 elif isinstance(expression, NULL_OK): 565 return None 566 elif is_null(a) or is_null(b): 567 return exp.null() 568 569 if a.is_number and b.is_number: 570 num_a = int(a.name) if a.is_int else Decimal(a.name) 571 num_b = int(b.name) if b.is_int else Decimal(b.name) 572 573 if isinstance(expression, exp.Add): 574 return exp.Literal.number(num_a + num_b) 575 if isinstance(expression, exp.Mul): 576 return exp.Literal.number(num_a * num_b) 577 578 # We only simplify Sub, Div if a and b have the same parent because they're not associative 579 if isinstance(expression, exp.Sub): 580 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 581 if isinstance(expression, exp.Div): 582 # engines have differing int div behavior so intdiv is not safe 583 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 584 return None 585 return exp.Literal.number(num_a / num_b) 586 587 boolean = eval_boolean(expression, num_a, num_b) 588 589 if boolean: 590 return boolean 591 elif a.is_string and b.is_string: 592 boolean = eval_boolean(expression, a.this, b.this) 593 594 if boolean: 595 return boolean 596 elif _is_date_literal(a) and isinstance(b, exp.Interval): 597 a, b = extract_date(a), extract_interval(b) 598 if a and b: 599 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 600 return date_literal(a + b) 601 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 602 return date_literal(a - b) 603 elif isinstance(a, exp.Interval) and _is_date_literal(b): 604 a, b = extract_interval(a), extract_date(b) 605 # you cannot subtract a date from an interval 606 if a and b and isinstance(expression, exp.Add): 607 
return date_literal(a + b) 608 elif _is_date_literal(a) and _is_date_literal(b): 609 if isinstance(expression, exp.Predicate): 610 a, b = extract_date(a), extract_date(b) 611 boolean = eval_boolean(expression, a, b) 612 if boolean: 613 return boolean 614 615 return None 616 617 618def simplify_parens(expression): 619 if not isinstance(expression, exp.Paren): 620 return expression 621 622 this = expression.this 623 parent = expression.parent 624 625 if not isinstance(this, exp.Select) and ( 626 not isinstance(parent, (exp.Condition, exp.Binary)) 627 or isinstance(parent, exp.Paren) 628 or not isinstance(this, exp.Binary) 629 or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate)) 630 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 631 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 632 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 633 ): 634 return this 635 return expression 636 637 638NONNULL_CONSTANTS = ( 639 exp.Literal, 640 exp.Boolean, 641) 642 643CONSTANTS = ( 644 exp.Literal, 645 exp.Boolean, 646 exp.Null, 647) 648 649 650def _is_nonnull_constant(expression: exp.Expression) -> bool: 651 return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression) 652 653 654def _is_constant(expression: exp.Expression) -> bool: 655 return isinstance(expression, CONSTANTS) or _is_date_literal(expression) 656 657 658def simplify_coalesce(expression): 659 # COALESCE(x) -> x 660 if ( 661 isinstance(expression, exp.Coalesce) 662 and (not expression.expressions or _is_nonnull_constant(expression.this)) 663 # COALESCE is also used as a Spark partitioning hint 664 and not isinstance(expression.parent, exp.Hint) 665 ): 666 return expression.this 667 668 if not isinstance(expression, COMPARISONS): 669 return expression 670 671 if isinstance(expression.left, exp.Coalesce): 672 coalesce = expression.left 673 other = expression.right 674 elif isinstance(expression.right, exp.Coalesce): 675 
coalesce = expression.right 676 other = expression.left 677 else: 678 return expression 679 680 # This transformation is valid for non-constants, 681 # but it really only does anything if they are both constants. 682 if not _is_constant(other): 683 return expression 684 685 # Find the first constant arg 686 for arg_index, arg in enumerate(coalesce.expressions): 687 if _is_constant(arg): 688 break 689 else: 690 return expression 691 692 coalesce.set("expressions", coalesce.expressions[:arg_index]) 693 694 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 695 # since we already remove COALESCE at the top of this function. 696 coalesce = coalesce if coalesce.expressions else coalesce.this 697 698 # This expression is more complex than when we started, but it will get simplified further 699 return exp.paren( 700 exp.or_( 701 exp.and_( 702 coalesce.is_(exp.null()).not_(copy=False), 703 expression.copy(), 704 copy=False, 705 ), 706 exp.and_( 707 coalesce.is_(exp.null()), 708 type(expression)(this=arg.copy(), expression=other.copy()), 709 copy=False, 710 ), 711 copy=False, 712 ) 713 ) 714 715 716CONCATS = (exp.Concat, exp.DPipe) 717 718 719def simplify_concat(expression): 720 """Reduces all groups that contain string literals by concatenating them.""" 721 if not isinstance(expression, CONCATS) or ( 722 # We can't reduce a CONCAT_WS call if we don't statically know the separator 723 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 724 ): 725 return expression 726 727 if isinstance(expression, exp.ConcatWs): 728 sep_expr, *expressions = expression.expressions 729 sep = sep_expr.name 730 concat_type = exp.ConcatWs 731 args = {} 732 else: 733 expressions = expression.expressions 734 sep = "" 735 concat_type = exp.Concat 736 args = { 737 "safe": expression.args.get("safe"), 738 "coalesce": expression.args.get("coalesce"), 739 } 740 741 new_args = [] 742 for is_string_group, group in itertools.groupby( 743 
expressions or expression.flatten(), lambda e: e.is_string 744 ): 745 if is_string_group: 746 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 747 else: 748 new_args.extend(group) 749 750 if len(new_args) == 1 and new_args[0].is_string: 751 return new_args[0] 752 753 if concat_type is exp.ConcatWs: 754 new_args = [sep_expr] + new_args 755 756 return concat_type(expressions=new_args, **args) 757 758 759def simplify_conditionals(expression): 760 """Simplifies expressions like IF, CASE if their condition is statically known.""" 761 if isinstance(expression, exp.Case): 762 this = expression.this 763 for case in expression.args["ifs"]: 764 cond = case.this 765 if this: 766 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 767 cond = cond.replace(this.pop().eq(cond)) 768 769 if always_true(cond): 770 return case.args["true"] 771 772 if always_false(cond): 773 case.pop() 774 if not expression.args["ifs"]: 775 return expression.args.get("default") or exp.null() 776 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 777 if always_true(expression.this): 778 return expression.args["true"] 779 if always_false(expression.this): 780 return expression.args.get("false") or exp.null() 781 782 return expression 783 784 785def simplify_startswith(expression: exp.Expression) -> exp.Expression: 786 """ 787 Reduces a prefix check to either TRUE or FALSE if both the string and the 788 prefix are statically known. 
789 790 Example: 791 >>> from sqlglot import parse_one 792 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 793 'TRUE' 794 """ 795 if ( 796 isinstance(expression, exp.StartsWith) 797 and expression.this.is_string 798 and expression.expression.is_string 799 ): 800 return exp.convert(expression.name.startswith(expression.expression.name)) 801 802 return expression 803 804 805DateRange = t.Tuple[datetime.date, datetime.date] 806 807 808def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 809 """ 810 Get the date range for a DATE_TRUNC equality comparison: 811 812 Example: 813 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 814 Returns: 815 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 816 """ 817 floor = date_floor(date, unit, dialect) 818 819 if date != floor: 820 # This will always be False, except for NULL values. 821 return None 822 823 return floor, floor + interval(unit) 824 825 826def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression: 827 """Get the logical expression for a date range""" 828 return exp.and_( 829 left >= date_literal(drange[0]), 830 left < date_literal(drange[1]), 831 copy=False, 832 ) 833 834 835def _datetrunc_eq( 836 left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect 837) -> t.Optional[exp.Expression]: 838 drange = _datetrunc_range(date, unit, dialect) 839 if not drange: 840 return None 841 842 return _datetrunc_eq_expression(left, drange) 843 844 845def _datetrunc_neq( 846 left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect 847) -> t.Optional[exp.Expression]: 848 drange = _datetrunc_range(date, unit, dialect) 849 if not drange: 850 return None 851 852 return exp.and_( 853 left < date_literal(drange[0]), 854 left >= date_literal(drange[1]), 855 copy=False, 856 ) 857 858 859DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], 
DateTruncBinaryTransform] = { 860 exp.LT: lambda l, dt, u, d: l 861 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)), 862 exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)), 863 exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)), 864 exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)), 865 exp.EQ: _datetrunc_eq, 866 exp.NEQ: _datetrunc_neq, 867} 868DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 869DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 870 871 872def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 873 return isinstance(left, DATETRUNCS) and _is_date_literal(right) 874 875 876@catch(ModuleNotFoundError, UnsupportedUnit) 877def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: 878 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 879 comparison = expression.__class__ 880 881 if isinstance(expression, DATETRUNCS): 882 date = extract_date(expression.this) 883 if date and expression.unit: 884 return date_literal(date_floor(date, expression.unit.name.lower(), dialect)) 885 elif comparison not in DATETRUNC_COMPARISONS: 886 return expression 887 888 if isinstance(expression, exp.Binary): 889 l, r = expression.left, expression.right 890 891 if not _is_datetrunc_predicate(l, r): 892 return expression 893 894 l = t.cast(exp.DateTrunc, l) 895 unit = l.unit.name.lower() 896 date = extract_date(r) 897 898 if not date: 899 return expression 900 901 return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression 902 elif isinstance(expression, exp.In): 903 l = expression.this 904 rs = expression.expressions 905 906 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 907 l = t.cast(exp.DateTrunc, l) 908 unit = l.unit.name.lower() 909 910 ranges = [] 911 for r in rs: 912 date = extract_date(r) 913 if not date: 
914 return expression 915 drange = _datetrunc_range(date, unit, dialect) 916 if drange: 917 ranges.append(drange) 918 919 if not ranges: 920 return expression 921 922 ranges = merge_ranges(ranges) 923 924 return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False) 925 926 return expression 927 928 929def sort_comparison(expression: exp.Expression) -> exp.Expression: 930 if expression.__class__ in COMPLEMENT_COMPARISONS: 931 l, r = expression.this, expression.expression 932 l_column = isinstance(l, exp.Column) 933 r_column = isinstance(r, exp.Column) 934 l_const = _is_constant(l) 935 r_const = _is_constant(r) 936 937 if (l_column and not r_column) or (r_const and not l_const): 938 return expression 939 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 940 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 941 this=r, expression=l 942 ) 943 return expression 944 945 946# CROSS joins result in an empty table if the right table is empty. 947# So we can only simplify certain types of joins to CROSS. 
948# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 949JOINS = { 950 ("", ""), 951 ("", "INNER"), 952 ("RIGHT", ""), 953 ("RIGHT", "OUTER"), 954} 955 956 957def remove_where_true(expression): 958 for where in expression.find_all(exp.Where): 959 if always_true(where.this): 960 where.parent.set("where", None) 961 for join in expression.find_all(exp.Join): 962 if ( 963 always_true(join.args.get("on")) 964 and not join.args.get("using") 965 and not join.args.get("method") 966 and (join.side, join.kind) in JOINS 967 ): 968 join.set("on", None) 969 join.set("side", None) 970 join.set("kind", "CROSS") 971 972 973def always_true(expression): 974 return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( 975 expression, exp.Literal 976 ) 977 978 979def always_false(expression): 980 return is_false(expression) or is_null(expression) 981 982 983def is_complement(a, b): 984 return isinstance(b, exp.Not) and b.this == a 985 986 987def is_false(a: exp.Expression) -> bool: 988 return type(a) is exp.Boolean and not a.this 989 990 991def is_null(a: exp.Expression) -> bool: 992 return type(a) is exp.Null 993 994 995def eval_boolean(expression, a, b): 996 if isinstance(expression, (exp.EQ, exp.Is)): 997 return boolean_literal(a == b) 998 if isinstance(expression, exp.NEQ): 999 return boolean_literal(a != b) 1000 if isinstance(expression, exp.GT): 1001 return boolean_literal(a > b) 1002 if isinstance(expression, exp.GTE): 1003 return boolean_literal(a >= b) 1004 if isinstance(expression, exp.LT): 1005 return boolean_literal(a < b) 1006 if isinstance(expression, exp.LTE): 1007 return boolean_literal(a <= b) 1008 return None 1009 1010 1011def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 1012 if isinstance(value, datetime.datetime): 1013 return value.date() 1014 if isinstance(value, datetime.date): 1015 return value 1016 try: 1017 return datetime.datetime.fromisoformat(value).date() 1018 except ValueError: 1019 return None 1020 1021 1022def 
cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 1023 if isinstance(value, datetime.datetime): 1024 return value 1025 if isinstance(value, datetime.date): 1026 return datetime.datetime(year=value.year, month=value.month, day=value.day) 1027 try: 1028 return datetime.datetime.fromisoformat(value) 1029 except ValueError: 1030 return None 1031 1032 1033def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1034 if not value: 1035 return None 1036 if to.is_type(exp.DataType.Type.DATE): 1037 return cast_as_date(value) 1038 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 1039 return cast_as_datetime(value) 1040 return None 1041 1042 1043def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1044 if isinstance(cast, exp.Cast): 1045 to = cast.to 1046 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 1047 to = exp.DataType.build(exp.DataType.Type.DATE) 1048 else: 1049 return None 1050 1051 if isinstance(cast.this, exp.Literal): 1052 value: t.Any = cast.this.name 1053 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 1054 value = extract_date(cast.this) 1055 else: 1056 return None 1057 return cast_value(value, to) 1058 1059 1060def _is_date_literal(expression: exp.Expression) -> bool: 1061 return extract_date(expression) is not None 1062 1063 1064def extract_interval(expression): 1065 try: 1066 n = int(expression.name) 1067 unit = expression.text("unit").lower() 1068 return interval(unit, n) 1069 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1070 return None 1071 1072 1073def date_literal(date): 1074 return exp.cast( 1075 exp.Literal.string(date), 1076 ( 1077 exp.DataType.Type.DATETIME 1078 if isinstance(date, datetime.datetime) 1079 else exp.DataType.Type.DATE 1080 ), 1081 ) 1082 1083 1084def interval(unit: str, n: int = 1): 1085 from dateutil.relativedelta import relativedelta 1086 1087 if unit == "year": 1088 return 
relativedelta(years=1 * n) 1089 if unit == "quarter": 1090 return relativedelta(months=3 * n) 1091 if unit == "month": 1092 return relativedelta(months=1 * n) 1093 if unit == "week": 1094 return relativedelta(weeks=1 * n) 1095 if unit == "day": 1096 return relativedelta(days=1 * n) 1097 if unit == "hour": 1098 return relativedelta(hours=1 * n) 1099 if unit == "minute": 1100 return relativedelta(minutes=1 * n) 1101 if unit == "second": 1102 return relativedelta(seconds=1 * n) 1103 1104 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1105 1106 1107def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1108 if unit == "year": 1109 return d.replace(month=1, day=1) 1110 if unit == "quarter": 1111 if d.month <= 3: 1112 return d.replace(month=1, day=1) 1113 elif d.month <= 6: 1114 return d.replace(month=4, day=1) 1115 elif d.month <= 9: 1116 return d.replace(month=7, day=1) 1117 else: 1118 return d.replace(month=10, day=1) 1119 if unit == "month": 1120 return d.replace(month=d.month, day=1) 1121 if unit == "week": 1122 # Assuming week starts on Monday (0) and ends on Sunday (6) 1123 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1124 if unit == "day": 1125 return d 1126 1127 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1128 1129 1130def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1131 floor = date_floor(d, unit, dialect) 1132 1133 if floor == d: 1134 return d 1135 1136 return floor + interval(unit) 1137 1138 1139def boolean_literal(condition): 1140 return exp.true() if condition else exp.false() 1141 1142 1143def _flat_simplify(expression, simplifier, root=True): 1144 if root or not expression.same_parent: 1145 operands = [] 1146 queue = deque(expression.flatten(unnest=False)) 1147 size = len(queue) 1148 1149 while queue: 1150 a = queue.popleft() 1151 1152 for b in queue: 1153 result = simplifier(expression, a, b) 1154 1155 if result and result is not expression: 1156 queue.remove(b) 
1157 queue.appendleft(result) 1158 break 1159 else: 1160 operands.append(a) 1161 1162 if len(operands) < size: 1163 return functools.reduce( 1164 lambda a, b: expression.__class__(this=a, expression=b), operands 1165 ) 1166 return expression 1167 1168 1169def gen(expression: t.Any) -> str: 1170 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1171 1172 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1173 generator is expensive so we have a bare minimum sql generator here. 1174 """ 1175 return Gen().gen(expression) 1176 1177 1178class Gen: 1179 def __init__(self): 1180 self.stack = [] 1181 self.sqls = [] 1182 1183 def gen(self, expression: exp.Expression) -> str: 1184 self.stack = [expression] 1185 self.sqls.clear() 1186 1187 while self.stack: 1188 node = self.stack.pop() 1189 1190 if isinstance(node, exp.Expression): 1191 exp_handler_name = f"{node.key}_sql" 1192 1193 if hasattr(self, exp_handler_name): 1194 getattr(self, exp_handler_name)(node) 1195 elif isinstance(node, exp.Func): 1196 self._function(node) 1197 else: 1198 key = node.key.upper() 1199 self.stack.append(f"{key} " if self._args(node) else key) 1200 elif type(node) is list: 1201 for n in reversed(node): 1202 if n is not None: 1203 self.stack.extend((n, ",")) 1204 if node: 1205 self.stack.pop() 1206 else: 1207 if node is not None: 1208 self.sqls.append(str(node)) 1209 1210 return "".join(self.sqls) 1211 1212 def add_sql(self, e: exp.Add) -> None: 1213 self._binary(e, " + ") 1214 1215 def alias_sql(self, e: exp.Alias) -> None: 1216 self.stack.extend( 1217 ( 1218 e.args.get("alias"), 1219 " AS ", 1220 e.args.get("this"), 1221 ) 1222 ) 1223 1224 def and_sql(self, e: exp.And) -> None: 1225 self._binary(e, " AND ") 1226 1227 def anonymous_sql(self, e: exp.Anonymous) -> None: 1228 this = e.this 1229 if isinstance(this, str): 1230 name = this.upper() 1231 elif isinstance(this, exp.Identifier): 1232 name = this.this 1233 name = f'"{name}"' if 
this.quoted else name.upper() 1234 else: 1235 raise ValueError( 1236 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1237 ) 1238 1239 self.stack.extend( 1240 ( 1241 ")", 1242 e.expressions, 1243 "(", 1244 name, 1245 ) 1246 ) 1247 1248 def between_sql(self, e: exp.Between) -> None: 1249 self.stack.extend( 1250 ( 1251 e.args.get("high"), 1252 " AND ", 1253 e.args.get("low"), 1254 " BETWEEN ", 1255 e.this, 1256 ) 1257 ) 1258 1259 def boolean_sql(self, e: exp.Boolean) -> None: 1260 self.stack.append("TRUE" if e.this else "FALSE") 1261 1262 def bracket_sql(self, e: exp.Bracket) -> None: 1263 self.stack.extend( 1264 ( 1265 "]", 1266 e.expressions, 1267 "[", 1268 e.this, 1269 ) 1270 ) 1271 1272 def column_sql(self, e: exp.Column) -> None: 1273 for p in reversed(e.parts): 1274 self.stack.extend((p, ".")) 1275 self.stack.pop() 1276 1277 def datatype_sql(self, e: exp.DataType) -> None: 1278 self._args(e, 1) 1279 self.stack.append(f"{e.this.name} ") 1280 1281 def div_sql(self, e: exp.Div) -> None: 1282 self._binary(e, " / ") 1283 1284 def dot_sql(self, e: exp.Dot) -> None: 1285 self._binary(e, ".") 1286 1287 def eq_sql(self, e: exp.EQ) -> None: 1288 self._binary(e, " = ") 1289 1290 def from_sql(self, e: exp.From) -> None: 1291 self.stack.extend((e.this, "FROM ")) 1292 1293 def gt_sql(self, e: exp.GT) -> None: 1294 self._binary(e, " > ") 1295 1296 def gte_sql(self, e: exp.GTE) -> None: 1297 self._binary(e, " >= ") 1298 1299 def identifier_sql(self, e: exp.Identifier) -> None: 1300 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1301 1302 def ilike_sql(self, e: exp.ILike) -> None: 1303 self._binary(e, " ILIKE ") 1304 1305 def in_sql(self, e: exp.In) -> None: 1306 self.stack.append(")") 1307 self._args(e, 1) 1308 self.stack.extend( 1309 ( 1310 "(", 1311 " IN ", 1312 e.this, 1313 ) 1314 ) 1315 1316 def intdiv_sql(self, e: exp.IntDiv) -> None: 1317 self._binary(e, " DIV ") 1318 1319 def is_sql(self, e: exp.Is) -> None: 1320 
self._binary(e, " IS ") 1321 1322 def like_sql(self, e: exp.Like) -> None: 1323 self._binary(e, " Like ") 1324 1325 def literal_sql(self, e: exp.Literal) -> None: 1326 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1327 1328 def lt_sql(self, e: exp.LT) -> None: 1329 self._binary(e, " < ") 1330 1331 def lte_sql(self, e: exp.LTE) -> None: 1332 self._binary(e, " <= ") 1333 1334 def mod_sql(self, e: exp.Mod) -> None: 1335 self._binary(e, " % ") 1336 1337 def mul_sql(self, e: exp.Mul) -> None: 1338 self._binary(e, " * ") 1339 1340 def neg_sql(self, e: exp.Neg) -> None: 1341 self._unary(e, "-") 1342 1343 def neq_sql(self, e: exp.NEQ) -> None: 1344 self._binary(e, " <> ") 1345 1346 def not_sql(self, e: exp.Not) -> None: 1347 self._unary(e, "NOT ") 1348 1349 def null_sql(self, e: exp.Null) -> None: 1350 self.stack.append("NULL") 1351 1352 def or_sql(self, e: exp.Or) -> None: 1353 self._binary(e, " OR ") 1354 1355 def paren_sql(self, e: exp.Paren) -> None: 1356 self.stack.extend( 1357 ( 1358 ")", 1359 e.this, 1360 "(", 1361 ) 1362 ) 1363 1364 def sub_sql(self, e: exp.Sub) -> None: 1365 self._binary(e, " - ") 1366 1367 def subquery_sql(self, e: exp.Subquery) -> None: 1368 self._args(e, 2) 1369 alias = e.args.get("alias") 1370 if alias: 1371 self.stack.append(alias) 1372 self.stack.extend((")", e.this, "(")) 1373 1374 def table_sql(self, e: exp.Table) -> None: 1375 self._args(e, 4) 1376 alias = e.args.get("alias") 1377 if alias: 1378 self.stack.append(alias) 1379 for p in reversed(e.parts): 1380 self.stack.extend((p, ".")) 1381 self.stack.pop() 1382 1383 def tablealias_sql(self, e: exp.TableAlias) -> None: 1384 columns = e.columns 1385 1386 if columns: 1387 self.stack.extend((")", columns, "(")) 1388 1389 self.stack.extend((e.this, " AS ")) 1390 1391 def var_sql(self, e: exp.Var) -> None: 1392 self.stack.append(e.this) 1393 1394 def _binary(self, e: exp.Binary, op: str) -> None: 1395 self.stack.extend((e.expression, op, e.this)) 1396 1397 def _unary(self, e: 
exp.Unary, op: str) -> None: 1398 self.stack.extend((e.this, op)) 1399 1400 def _function(self, e: exp.Func) -> None: 1401 self.stack.extend( 1402 ( 1403 ")", 1404 list(e.args.values()), 1405 "(", 1406 e.sql_name(), 1407 ) 1408 ) 1409 1410 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1411 kvs = [] 1412 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1413 1414 for k in arg_types or arg_types: 1415 v = node.args.get(k) 1416 1417 if v is not None: 1418 kvs.append([f":{k}", v]) 1419 if kvs: 1420 self.stack.append(kvs) 1421 return True 1422 return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and handed to `simplify_datetrunc`

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same treatment for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The steps above may have produced a new node, so re-attach it to the original parent
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Re-apply _simplify through while_changing, then drop trivially true WHERE / JOIN conditions
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes that should be swallowed.

    Returns:
        A decorator. The decorated function returns its first argument (`expression`)
        unchanged whenever one of `exceptions` is raised, and the regular result otherwise.
    """

    def decorator(func):
        # functools.wraps keeps the name, docstring and signature metadata of `func`, so
        # introspection and documentation tools see the real function, not `wrapped`.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Look at the parent before the node gets replaced by the new conjunction
    negate = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    expression = exp.and_(lower_bound, upper_bound, copy=False)

    # Under a NOT the conjunction is parenthesized so the negation applies to all of it
    if negate:
        expression = exp.paren(expression, copy=False)

    return expression
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT expression; anything else is returned unchanged.

    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also handled: NOT NULL -> NULL, comparisons that have an entry in
    COMPLEMENT_COMPARISONS, NOT TRUE -> FALSE, NOT FALSE -> TRUE and NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this
    if is_null(operand):
        return exp.null()

    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        complement = COMPLEMENT_COMPARISONS[operand_type]
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: negate both sides and switch to the other connector
            dual = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Unwrap operands of a connector that are themselves the same kind of connector.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold pairs of operands of an AND / OR whose result is statically known.

    Non-connector expressions are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left):
                # TRUE AND TRUE -> TRUE, TRUE AND x -> x
                return exp.true() if always_true(right) else right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left):
                # FALSE OR FALSE -> FALSE, FALSE OR NULL -> NULL, FALSE OR x -> x
                if is_false(right):
                    return exp.false()
                if is_null(right):
                    return exp.null()
                return right
            if is_null(left) and (is_null(right) or is_false(right)):
                return exp.null()
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keyed by the pseudo sql of each operand, so duplicates collapse into one entry
    by_sql = {gen(operand): operand for operand in operands}
    keyed = tuple(by_sql.items())

    # check if the operands are already sorted, if not sort them
    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))

    if out_of_order:
        expression = build(*(operand for _, operand in sorted(keyed)), copy=False)
    elif len(by_sql) < len(operands):
        # we didn't have to sort but maybe we need to dedup
        expression = build(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to the operands of a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten with `replace`; the (possibly mutated) input expression
    is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: OR when simplifying an AND and vice versa
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair of operands is checked, so both (a, b) and (b, a) are visited
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for a constant that the connector rules fold away
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # the operands of b are a strict subset of the operands of a
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one operand and their other operands are complements
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only AND expressions that are normalized (DNF) are rewritten; the columns are
    replaced in place and the input expression is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node inside `column = literal`, the literal)
        constant_mapping = {}
        # exp.If nodes are pruned from the walk
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the column node that defines the constant itself (identity check)
                # and columns that are compared with `IS NULL`
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Run the decorated simplification; if it raises one of the configured
    # exceptions, fall back to the untouched input expression.
    try:
        result = func(expression, *args, **kwargs)
    except exceptions:
        return expression
    return result
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
    l     r
  x + 1 = 3
  a   b
def simplify_literals(expression, root=True):
    """Fold literal operands of binary operations, negations and inverse date operations.

    Expressions that match none of these cases are returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating a negative number drops its sign, otherwise a sign is prepended
            negated = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(negated)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Drop a pair of parentheses when one of the conditions below holds.

    Returns the wrapped expression if the parentheses are removed, otherwise the
    input expression.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # A parenthesized SELECT always keeps its parentheses
    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    A COALESCE with a single argument, or whose first argument is a non-null constant,
    is replaced by that argument. A comparison of a COALESCE with a constant is
    rewritten so that the arguments starting at the first constant one are dropped from
    the COALESCE. Anything else is returned unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # The COALESCE may be on either side of the comparison
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the arguments that precede the first constant one (mutates the COALESCE)
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        concat_type = exp.Concat
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    # Merge each run of adjacent string literals into a single literal
    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for is_string_group, group in runs:
        if is_string_group:
            merged.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            merged.extend(group)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        merged = [sep_expr] + merged

    return concat_type(expressions=merged, **extra_args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For a CASE, the result of the first branch whose condition is always true is
    returned, and branches whose condition is always false are removed; when no branch
    is left, the default (or NULL) is returned. An IF outside of a CASE is replaced by
    the branch selected by its condition. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                # NOTE(review): `this.pop()` detaches the operand from the CASE, while the
                # local `this` stays truthy for the remaining branches - confirm this is intended
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # NOTE(review): removes the branch while `ifs` is being iterated - presumably
                # safe with how `pop` updates the parent, verify against Expression.pop
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    # Both the string and the prefix have to be string literals
    if not (expression.this.is_string and expression.expression.is_string):
        return expression

    text = expression.name
    prefix = expression.expression.name
    return exp.convert(text.startswith(prefix))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
def wrapped(expression, *args, **kwargs):
    # Run the decorated simplification; if it raises one of the configured
    # exceptions, fall back to the untouched input expression.
    try:
        result = func(expression, *args, **kwargs)
    except exceptions:
        return expression
    return result
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    A column stays on the left of a non-column and a constant stays on the right of a
    non-constant. Otherwise the operands are swapped when that ordering is violated or
    when the pseudo sql of the left operand sorts after the right one. A swap uses the
    class from INVERSE_COMPARISONS, falling back to the same class.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (left_is_column and not right_is_column) or (
        right_is_const and not left_is_const
    )
    if already_ordered:
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if needs_swap:
        comparison_type = expression.__class__
        swapped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
        return swapped_type(this=right, expression=left)
    return expression
def remove_where_true(expression):
    """Remove WHERE clauses that are always true and turn always-true joins into CROSS joins.

    The expression is modified in place; nothing is returned.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        # Only the side / kind combinations listed in JOINS are rewritten
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal with the result, or None if `expression` is not one of
    the supported comparisons (EQ / IS, NEQ, GT, GTE, LT, LTE).
    """
    # Checked in order; the lambdas delay the comparison until a node type matches
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        ((exp.NEQ,), lambda: a != b),
        ((exp.GT,), lambda: a > b),
        ((exp.GTE,), lambda: a >= b),
        ((exp.LT,), lambda: a < b),
        ((exp.LTE,), lambda: a <= b),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`.

    Args:
        value: a datetime (its date part is used), a date (returned as is) or an
            ISO formatted string.

    Returns:
        The date, or None if `value` cannot be interpreted as one.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: not an ISO formatted string, TypeError: not a string at all
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`.

    Args:
        value: a datetime (returned as is), a date (converted to midnight of that day)
            or an ISO formatted string.

    Returns:
        The datetime, or None if `value` cannot be interpreted as one.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: not an ISO formatted string, TypeError: not a string at all
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast `value` to the Python date type that matches the data type `to`.

    Args:
        value: the value to cast; falsy values (None, empty string) are not cast.
        to: the target data type.

    Returns:
        A date when `to` is DATE, a datetime when `to` is any other temporal type,
        and None when the value is falsy, cannot be cast or `to` is not temporal.
    """
    if not value:
        return None
    # DATE has to be checked before the generic temporal types
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a cast of a literal to a temporal type into a Python date or datetime.

    Args:
        cast: a CAST, or a TS_OR_DS_TO_DATE without a format, whose operand is a
            literal or another such cast.

    Returns:
        The date / datetime value, or None if `cast` is any other expression or its
        value cannot be cast.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without a format, TS_OR_DS_TO_DATE is handled like a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are evaluated from the inside out
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Build a `relativedelta` that spans `n` times the given `unit`.

    Supported units are year, quarter, month, week, day, hour, minute and second.

    Raises:
        UnsupportedUnit: if `unit` is none of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, amount of that keyword per unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = span
    return relativedelta(**{keyword: factor * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the `unit` period that contains it.

    Args:
        d: the date to round down.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only its `WEEK_OFFSET` is read, to find the first day of the week.

    Returns:
        The first day of the period; never later than `d`.

    Raises:
        UnsupportedUnit: if `unit` is none of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday and 6 for Sunday; WEEK_OFFSET shifts the first day
        # of the week relative to Monday. The modulo keeps the step within 0..6 days, so
        # a date that already is the first day of the week maps to itself and the result
        # is never after `d`.
        # NOTE(review): assumes WEEK_OFFSET = -1 means weeks start on Sunday - confirm
        # against the Dialect definition.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    # A fresh generator is created for every call
    generator = Gen()
    return generator.gen(expression)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
class Gen:
    """Bare-minimum pseudo-SQL generator driven by an explicit stack.

    `gen` repeatedly pops items off `self.stack`: expressions are expanded by
    a `<key>_sql` handler (or a generic fallback), lists are expanded into
    comma-separated items, and everything else is appended to `self.sqls` as
    text. Because the stack is LIFO, handlers push their pieces in the
    reverse of the order in which they should appear in the output.
    """

    def __init__(self):
        # Work stack of items (expressions, lists, strings) still to be emitted.
        self.stack = []
        # Output fragments, concatenated by `gen` once the stack is drained.
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` into a pseudo-SQL string.

        Args:
            expression: the root item to render.

        Returns:
            The concatenation of all emitted fragments.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if node is None:
                # Missing optional pieces are pushed as None and emit nothing.
                continue

            if type(node) is str:
                # Handlers push plain strings for operators, separators and
                # names; they are emitted verbatim.
                self.sqls.append(node)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that follows the first emitted item. Only
                # do so when something was pushed: a list made up solely of
                # None would otherwise pop an unrelated entry (or raise on an
                # empty stack).
                if pushed:
                    self.stack.pop()
            elif isinstance(node, exp.Expression):
                handler = getattr(self, f"{node.key}_sql", None)

                if handler is not None:
                    handler(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: upper-cased key, followed by a space
                    # and the set args when there are any.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            else:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Push `NAME(expressions)` for an anonymous function call.

        Raises:
            ValueError: if `e.this` is neither a str nor an Identifier.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted identifiers keep their case; unquoted ones are upper-cased.
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Remove the dot in front of the first part; skip when there are no
        # parts so that an unrelated stack entry is never popped.
        if parts:
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Remaining args (everything after the first arg type) follow the name.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case keyword, unlike " ILIKE "; kept as-is so
        # the generated strings do not change.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Args after the first two arg types are rendered generically.
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Args after the first four arg types are rendered generically.
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Remove the dot in front of the first part; skip when there are no
        # parts so that the alias/args pushed above are not popped instead.
        if parts:
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this <op> expression` (in reverse, since the stack is LIFO)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `<op>this` (in reverse, since the stack is LIFO)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `SQL_NAME(arg values...)` for a function expression."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the set args of `node` as a list of `[":name", value]` pairs.

        Args:
            node: the expression whose args are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one arg was set (and therefore pushed), else False.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
1184 def gen(self, expression: exp.Expression) -> str: 1185 self.stack = [expression] 1186 self.sqls.clear() 1187 1188 while self.stack: 1189 node = self.stack.pop() 1190 1191 if isinstance(node, exp.Expression): 1192 exp_handler_name = f"{node.key}_sql" 1193 1194 if hasattr(self, exp_handler_name): 1195 getattr(self, exp_handler_name)(node) 1196 elif isinstance(node, exp.Func): 1197 self._function(node) 1198 else: 1199 key = node.key.upper() 1200 self.stack.append(f"{key} " if self._args(node) else key) 1201 elif type(node) is list: 1202 for n in reversed(node): 1203 if n is not None: 1204 self.stack.extend((n, ",")) 1205 if node: 1206 self.stack.pop() 1207 else: 1208 if node is not None: 1209 self.sqls.append(str(node)) 1210 1211 return "".join(self.sqls)
1228 def anonymous_sql(self, e: exp.Anonymous) -> None: 1229 this = e.this 1230 if isinstance(this, str): 1231 name = this.upper() 1232 elif isinstance(this, exp.Identifier): 1233 name = this.this 1234 name = f'"{name}"' if this.quoted else name.upper() 1235 else: 1236 raise ValueError( 1237 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1238 ) 1239 1240 self.stack.extend( 1241 ( 1242 ")", 1243 e.expressions, 1244 "(", 1245 name, 1246 ) 1247 )