sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import functools 5import itertools 6import typing as t 7from collections import deque 8from decimal import Decimal 9 10import sqlglot 11from sqlglot import Dialect, exp 12from sqlglot.helper import first, merge_ranges, while_changing 13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 14 15if t.TYPE_CHECKING: 16 from sqlglot.dialects.dialect import DialectType 17 18 DateTruncBinaryTransform = t.Callable[ 19 [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] 20 ] 21 22# Final means that an expression should not be simplified 23FINAL = "final" 24 25# Value ranges for byte-sized signed/unsigned integers 26TINYINT_MIN = -128 27TINYINT_MAX = 127 28UTINYINT_MIN = 0 29UTINYINT_MAX = 255 30 31 32class UnsupportedUnit(Exception): 33 pass 34 35 36def simplify( 37 expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None 38): 39 """ 40 Rewrite sqlglot AST to simplify expressions. 
41 42 Example: 43 >>> import sqlglot 44 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 45 >>> simplify(expression).sql() 46 'TRUE' 47 48 Args: 49 expression (sqlglot.Expression): expression to simplify 50 constant_propagation: whether the constant propagation rule should be used 51 52 Returns: 53 sqlglot.Expression: simplified expression 54 """ 55 56 dialect = Dialect.get_or_raise(dialect) 57 58 def _simplify(expression, root=True): 59 if expression.meta.get(FINAL): 60 return expression 61 62 # group by expressions cannot be simplified, for example 63 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 64 # the projection must exactly match the group by key 65 group = expression.args.get("group") 66 67 if group and hasattr(expression, "selects"): 68 groups = set(group.expressions) 69 group.meta[FINAL] = True 70 71 for e in expression.selects: 72 for node in e.walk(): 73 if node in groups: 74 e.meta[FINAL] = True 75 break 76 77 having = expression.args.get("having") 78 if having: 79 for node in having.walk(): 80 if node in groups: 81 having.meta[FINAL] = True 82 break 83 84 # Pre-order transformations 85 node = expression 86 node = rewrite_between(node) 87 node = uniq_sort(node, root) 88 node = absorb_and_eliminate(node, root) 89 node = simplify_concat(node) 90 node = simplify_conditionals(node) 91 92 if constant_propagation: 93 node = propagate_constants(node, root) 94 95 exp.replace_children(node, lambda e: _simplify(e, False)) 96 97 # Post-order transformations 98 node = simplify_not(node) 99 node = flatten(node) 100 node = simplify_connectors(node, root) 101 node = remove_complements(node, root) 102 node = simplify_coalesce(node) 103 node.parent = expression.parent 104 node = simplify_literals(node, root) 105 node = simplify_equality(node) 106 node = simplify_parens(node) 107 node = simplify_datetrunc(node, dialect) 108 node = sort_comparison(node) 109 node = simplify_startswith(node) 110 111 if root: 112 expression.replace(node) 113 return node 114 115 expression 
= while_changing(expression, _simplify) 116 remove_where_true(expression) 117 return expression 118 119 120def catch(*exceptions): 121 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 122 123 def decorator(func): 124 def wrapped(expression, *args, **kwargs): 125 try: 126 return func(expression, *args, **kwargs) 127 except exceptions: 128 return expression 129 130 return wrapped 131 132 return decorator 133 134 135def rewrite_between(expression: exp.Expression) -> exp.Expression: 136 """Rewrite x between y and z to x >= y AND x <= z. 137 138 This is done because comparison simplification is only done on lt/lte/gt/gte. 139 """ 140 if isinstance(expression, exp.Between): 141 negate = isinstance(expression.parent, exp.Not) 142 143 expression = exp.and_( 144 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 145 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 146 copy=False, 147 ) 148 149 if negate: 150 expression = exp.paren(expression, copy=False) 151 152 return expression 153 154 155COMPLEMENT_COMPARISONS = { 156 exp.LT: exp.GTE, 157 exp.GT: exp.LTE, 158 exp.LTE: exp.GT, 159 exp.GTE: exp.LT, 160 exp.EQ: exp.NEQ, 161 exp.NEQ: exp.EQ, 162} 163 164 165def simplify_not(expression): 166 """ 167 Demorgan's Law 168 NOT (x OR y) -> NOT x AND NOT y 169 NOT (x AND y) -> NOT x OR NOT y 170 """ 171 if isinstance(expression, exp.Not): 172 this = expression.this 173 if is_null(this): 174 return exp.null() 175 if this.__class__ in COMPLEMENT_COMPARISONS: 176 return COMPLEMENT_COMPARISONS[this.__class__]( 177 this=this.this, expression=this.expression 178 ) 179 if isinstance(this, exp.Paren): 180 condition = this.unnest() 181 if isinstance(condition, exp.And): 182 return exp.paren( 183 exp.or_( 184 exp.not_(condition.left, copy=False), 185 exp.not_(condition.right, copy=False), 186 copy=False, 187 ) 188 ) 189 if isinstance(condition, exp.Or): 190 return exp.paren( 191 exp.and_( 192 
exp.not_(condition.left, copy=False), 193 exp.not_(condition.right, copy=False), 194 copy=False, 195 ) 196 ) 197 if is_null(condition): 198 return exp.null() 199 if always_true(this): 200 return exp.false() 201 if is_false(this): 202 return exp.true() 203 if isinstance(this, exp.Not): 204 # double negation 205 # NOT NOT x -> x 206 return this.this 207 return expression 208 209 210def flatten(expression): 211 """ 212 A AND (B AND C) -> A AND B AND C 213 A OR (B OR C) -> A OR B OR C 214 """ 215 if isinstance(expression, exp.Connector): 216 for node in expression.args.values(): 217 child = node.unnest() 218 if isinstance(child, expression.__class__): 219 node.replace(child) 220 return expression 221 222 223def simplify_connectors(expression, root=True): 224 def _simplify_connectors(expression, left, right): 225 if left == right: 226 return left 227 if isinstance(expression, exp.And): 228 if is_false(left) or is_false(right): 229 return exp.false() 230 if is_null(left) or is_null(right): 231 return exp.null() 232 if always_true(left) and always_true(right): 233 return exp.true() 234 if always_true(left): 235 return right 236 if always_true(right): 237 return left 238 return _simplify_comparison(expression, left, right) 239 elif isinstance(expression, exp.Or): 240 if always_true(left) or always_true(right): 241 return exp.true() 242 if is_false(left) and is_false(right): 243 return exp.false() 244 if ( 245 (is_null(left) and is_null(right)) 246 or (is_null(left) and is_false(right)) 247 or (is_false(left) and is_null(right)) 248 ): 249 return exp.null() 250 if is_false(left): 251 return right 252 if is_false(right): 253 return left 254 return _simplify_comparison(expression, left, right, or_=True) 255 256 if isinstance(expression, exp.Connector): 257 return _flat_simplify(expression, _simplify_connectors, root) 258 return expression 259 260 261LT_LTE = (exp.LT, exp.LTE) 262GT_GTE = (exp.GT, exp.GTE) 263 264COMPARISONS = ( 265 *LT_LTE, 266 *GT_GTE, 267 exp.EQ, 268 
exp.NEQ, 269 exp.Is, 270) 271 272INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 273 exp.LT: exp.GT, 274 exp.GT: exp.LT, 275 exp.LTE: exp.GTE, 276 exp.GTE: exp.LTE, 277} 278 279NONDETERMINISTIC = (exp.Rand, exp.Randn) 280 281 282def _simplify_comparison(expression, left, right, or_=False): 283 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 284 ll, lr = left.args.values() 285 rl, rr = right.args.values() 286 287 largs = {ll, lr} 288 rargs = {rl, rr} 289 290 matching = largs & rargs 291 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 292 293 if matching and columns: 294 try: 295 l = first(largs - columns) 296 r = first(rargs - columns) 297 except StopIteration: 298 return expression 299 300 if l.is_number and r.is_number: 301 l = float(l.name) 302 r = float(r.name) 303 elif l.is_string and r.is_string: 304 l = l.name 305 r = r.name 306 else: 307 l = extract_date(l) 308 if not l: 309 return None 310 r = extract_date(r) 311 if not r: 312 return None 313 314 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 315 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 316 return left if (av > bv if or_ else av <= bv) else right 317 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 318 return left if (av < bv if or_ else av >= bv) else right 319 320 # we can't ever shortcut to true because the column could be null 321 if not or_: 322 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 323 if av <= bv: 324 return exp.false() 325 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 326 if av >= bv: 327 return exp.false() 328 elif isinstance(a, exp.EQ): 329 if isinstance(b, exp.LT): 330 return exp.false() if av >= bv else a 331 if isinstance(b, exp.LTE): 332 return exp.false() if av > bv else a 333 if isinstance(b, exp.GT): 334 return exp.false() if av <= bv else a 335 if isinstance(b, exp.GTE): 336 return exp.false() if av < bv else a 337 if isinstance(b, 
exp.NEQ): 338 return exp.false() if av == bv else a 339 return None 340 341 342def remove_complements(expression, root=True): 343 """ 344 Removing complements. 345 346 A AND NOT A -> FALSE 347 A OR NOT A -> TRUE 348 """ 349 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 350 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 351 352 for a, b in itertools.permutations(expression.flatten(), 2): 353 if is_complement(a, b): 354 return complement 355 return expression 356 357 358def uniq_sort(expression, root=True): 359 """ 360 Uniq and sort a connector. 361 362 C AND A AND B AND B -> A AND B AND C 363 """ 364 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 365 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 366 flattened = tuple(expression.flatten()) 367 deduped = {gen(e): e for e in flattened} 368 arr = tuple(deduped.items()) 369 370 # check if the operands are already sorted, if not sort them 371 # A AND C AND B -> A AND B AND C 372 for i, (sql, e) in enumerate(arr[1:]): 373 if sql < arr[i][0]: 374 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 375 break 376 else: 377 # we didn't have to sort but maybe we need to dedup 378 if len(deduped) < len(flattened): 379 expression = result_func(*deduped.values(), copy=False) 380 381 return expression 382 383 384def absorb_and_eliminate(expression, root=True): 385 """ 386 absorption: 387 A AND (A OR B) -> A 388 A OR (A AND B) -> A 389 A AND (NOT A OR B) -> A AND B 390 A OR (NOT A AND B) -> A OR B 391 elimination: 392 (A AND B) OR (A AND NOT B) -> A 393 (A OR B) AND (A OR NOT B) -> A 394 """ 395 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 396 kind = exp.Or if isinstance(expression, exp.And) else exp.And 397 398 for a, b in itertools.permutations(expression.flatten(), 2): 399 if isinstance(a, kind): 400 aa, ab = a.unnest_operands() 401 402 # absorb 403 if 
is_complement(b, aa): 404 aa.replace(exp.true() if kind == exp.And else exp.false()) 405 elif is_complement(b, ab): 406 ab.replace(exp.true() if kind == exp.And else exp.false()) 407 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 408 a.replace(exp.false() if kind == exp.And else exp.true()) 409 elif isinstance(b, kind): 410 # eliminate 411 rhs = b.unnest_operands() 412 ba, bb = rhs 413 414 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 415 a.replace(aa) 416 b.replace(aa) 417 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 418 a.replace(ab) 419 b.replace(ab) 420 421 return expression 422 423 424def propagate_constants(expression, root=True): 425 """ 426 Propagate constants for conjunctions in DNF: 427 428 SELECT * FROM t WHERE a = b AND b = 5 becomes 429 SELECT * FROM t WHERE a = 5 AND b = 5 430 431 Reference: https://www.sqlite.org/optoverview.html 432 """ 433 434 if ( 435 isinstance(expression, exp.And) 436 and (root or not expression.same_parent) 437 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 438 ): 439 constant_mapping = {} 440 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 441 if isinstance(expr, exp.EQ): 442 l, r = expr.left, expr.right 443 444 # TODO: create a helper that can be used to detect nested literal expressions such 445 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 446 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 447 constant_mapping[l] = (id(l), r) 448 449 if constant_mapping: 450 for column in find_all_in_scope(expression, exp.Column): 451 parent = column.parent 452 column_id, constant = constant_mapping.get(column) or (None, None) 453 if ( 454 column_id is not None 455 and id(column) != column_id 456 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 457 ): 458 column.replace(constant.copy()) 459 460 return expression 461 462 
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    """Return whether `expression` reports itself as a number."""
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    """Return whether `expression` is an INTERVAL that `extract_interval` can evaluate."""
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ not in INVERSE_OPS:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
            # Operands may only be swapped for commutative addition:
            # 1 - x = 3 is not equivalent to x - 1 = 3
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Fold binary operations on literals, negated number literals and date +/- interval ops."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            # Negating an already negative literal drops the sign instead of stacking it
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression


NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)


def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
    """Strip (possibly nested) integer casts around byte-sized integer literals."""
    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
        this = _simplify_integer_cast(expr.this)
    else:
        this = expr.this

    if isinstance(expr, exp.Cast) and this.is_int:
        num = int(this.name)

        # Remove the (up)cast from small (byte-sized) integers in predicates which is
        # side-effect free. Downcasts on any integer type might cause overflow, thus the
        # cast cannot be eliminated and the behavior is engine-dependent
        if (
            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
        ) or (
            UTINYINT_MIN <= num <= UTINYINT_MAX
            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
        ):
            return this

    return expr


def _simplify_binary(expression, a, b):
    """
    Evaluate the binary operation `expression` over operands `a` and `b` when both are
    statically known (NULLs, numbers, strings, date literals, intervals).

    Returns the folded expression, or None when nothing can be folded.
    """
    if isinstance(expression, COMPARISONS):
        a = _simplify_integer_cast(a)
        b = _simplify_integer_cast(b)

    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, NULL_OK):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe;
            # division by a literal zero is left to the engine instead of raising here
            if (
                (isinstance(num_a, int) and isinstance(num_b, int))
                or a.parent is not b.parent
                or num_b == 0
            ):
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None


def simplify_parens(expression):
    """Return the wrapped expression when the parentheses of a Paren node are redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return whether `expression` is one of `exp.NONNULL_CONSTANTS` or a date literal."""
    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    """Return whether `expression` is one of `exp.CONSTANTS` or a date literal."""
    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
    """
    Remove redundant COALESCE calls and expand comparisons between a COALESCE with a
    constant argument and a constant into an explicit NULL / NOT NULL case split.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Keep the operands on their original sides: comparisons such as < are not symmetric,
    # so `1 < COALESCE(x, 2)` must fall back to `1 < 2`, not `2 < 1`
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Merge each run of adjacent string literals into one literal, keep everything else as is
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot, since branches can be removed while looping
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                # Only short-circuit when every preceding branch has been eliminated,
                # otherwise an earlier, statically unknown branch could still match first
                if expression.args["ifs"][0] is case:
                    return case.args["true"]
            elif always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if (
        isinstance(expression, exp.StartsWith)
        and expression.this.is_string
        and expression.expression.is_string
    ):
        return exp.convert(expression.name.startswith(expression.expression.name))

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit, dialect)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as a range check, or None if not possible."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) <> date` as lying outside the range, or None."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    # The complement of [min, max) is `left < min OR left >= max`; a conjunction of the two
    # could never be satisfied. The parentheses keep the OR grouped under an AND parent.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        ),
        copy=False,
    )


DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return whether `left` is a date truncation and `right` is a date literal."""
    return isinstance(left, DATETRUNCS) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        # Truncating a date literal can be evaluated right away
        date = extract_date(expression.this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Reorder comparison operands so that columns come first and constants last."""
    if expression.__class__ in COMPLEMENT_COMPARISONS:
        l, r = expression.this, expression.expression
        l_column = isinstance(l, exp.Column)
        r_column = isinstance(r, exp.Column)
        l_const = _is_constant(l)
        r_const = _is_constant(r)

        if (l_column and not r_column) or (r_const and not l_const):
            return expression
        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
            # Swapping the operands requires mirroring the operator (< becomes >, etc.)
            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                this=r, expression=l
            )
    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
973# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 974JOINS = { 975 ("", ""), 976 ("", "INNER"), 977 ("RIGHT", ""), 978 ("RIGHT", "OUTER"), 979} 980 981 982def remove_where_true(expression): 983 for where in expression.find_all(exp.Where): 984 if always_true(where.this): 985 where.pop() 986 for join in expression.find_all(exp.Join): 987 if ( 988 always_true(join.args.get("on")) 989 and not join.args.get("using") 990 and not join.args.get("method") 991 and (join.side, join.kind) in JOINS 992 ): 993 join.args["on"].pop() 994 join.set("side", None) 995 join.set("kind", "CROSS") 996 997 998def always_true(expression): 999 return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( 1000 expression, exp.Literal 1001 ) 1002 1003 1004def always_false(expression): 1005 return is_false(expression) or is_null(expression) 1006 1007 1008def is_complement(a, b): 1009 return isinstance(b, exp.Not) and b.this == a 1010 1011 1012def is_false(a: exp.Expression) -> bool: 1013 return type(a) is exp.Boolean and not a.this 1014 1015 1016def is_null(a: exp.Expression) -> bool: 1017 return type(a) is exp.Null 1018 1019 1020def eval_boolean(expression, a, b): 1021 if isinstance(expression, (exp.EQ, exp.Is)): 1022 return boolean_literal(a == b) 1023 if isinstance(expression, exp.NEQ): 1024 return boolean_literal(a != b) 1025 if isinstance(expression, exp.GT): 1026 return boolean_literal(a > b) 1027 if isinstance(expression, exp.GTE): 1028 return boolean_literal(a >= b) 1029 if isinstance(expression, exp.LT): 1030 return boolean_literal(a < b) 1031 if isinstance(expression, exp.LTE): 1032 return boolean_literal(a <= b) 1033 return None 1034 1035 1036def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 1037 if isinstance(value, datetime.datetime): 1038 return value.date() 1039 if isinstance(value, datetime.date): 1040 return value 1041 try: 1042 return datetime.datetime.fromisoformat(value).date() 1043 except ValueError: 1044 return None 1045 1046 1047def 
cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 1048 if isinstance(value, datetime.datetime): 1049 return value 1050 if isinstance(value, datetime.date): 1051 return datetime.datetime(year=value.year, month=value.month, day=value.day) 1052 try: 1053 return datetime.datetime.fromisoformat(value) 1054 except ValueError: 1055 return None 1056 1057 1058def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1059 if not value: 1060 return None 1061 if to.is_type(exp.DataType.Type.DATE): 1062 return cast_as_date(value) 1063 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 1064 return cast_as_datetime(value) 1065 return None 1066 1067 1068def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1069 if isinstance(cast, exp.Cast): 1070 to = cast.to 1071 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 1072 to = exp.DataType.build(exp.DataType.Type.DATE) 1073 else: 1074 return None 1075 1076 if isinstance(cast.this, exp.Literal): 1077 value: t.Any = cast.this.name 1078 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 1079 value = extract_date(cast.this) 1080 else: 1081 return None 1082 return cast_value(value, to) 1083 1084 1085def _is_date_literal(expression: exp.Expression) -> bool: 1086 return extract_date(expression) is not None 1087 1088 1089def extract_interval(expression): 1090 try: 1091 n = int(expression.name) 1092 unit = expression.text("unit").lower() 1093 return interval(unit, n) 1094 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1095 return None 1096 1097 1098def date_literal(date): 1099 return exp.cast( 1100 exp.Literal.string(date), 1101 ( 1102 exp.DataType.Type.DATETIME 1103 if isinstance(date, datetime.datetime) 1104 else exp.DataType.Type.DATE 1105 ), 1106 ) 1107 1108 1109def interval(unit: str, n: int = 1): 1110 from dateutil.relativedelta import relativedelta 1111 1112 if unit == "year": 1113 return 
relativedelta(years=1 * n) 1114 if unit == "quarter": 1115 return relativedelta(months=3 * n) 1116 if unit == "month": 1117 return relativedelta(months=1 * n) 1118 if unit == "week": 1119 return relativedelta(weeks=1 * n) 1120 if unit == "day": 1121 return relativedelta(days=1 * n) 1122 if unit == "hour": 1123 return relativedelta(hours=1 * n) 1124 if unit == "minute": 1125 return relativedelta(minutes=1 * n) 1126 if unit == "second": 1127 return relativedelta(seconds=1 * n) 1128 1129 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1130 1131 1132def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1133 if unit == "year": 1134 return d.replace(month=1, day=1) 1135 if unit == "quarter": 1136 if d.month <= 3: 1137 return d.replace(month=1, day=1) 1138 elif d.month <= 6: 1139 return d.replace(month=4, day=1) 1140 elif d.month <= 9: 1141 return d.replace(month=7, day=1) 1142 else: 1143 return d.replace(month=10, day=1) 1144 if unit == "month": 1145 return d.replace(month=d.month, day=1) 1146 if unit == "week": 1147 # Assuming week starts on Monday (0) and ends on Sunday (6) 1148 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1149 if unit == "day": 1150 return d 1151 1152 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1153 1154 1155def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1156 floor = date_floor(d, unit, dialect) 1157 1158 if floor == d: 1159 return d 1160 1161 return floor + interval(unit) 1162 1163 1164def boolean_literal(condition): 1165 return exp.true() if condition else exp.false() 1166 1167 1168def _flat_simplify(expression, simplifier, root=True): 1169 if root or not expression.same_parent: 1170 operands = [] 1171 queue = deque(expression.flatten(unnest=False)) 1172 size = len(queue) 1173 1174 while queue: 1175 a = queue.popleft() 1176 1177 for b in queue: 1178 result = simplifier(expression, a, b) 1179 1180 if result and result is not expression: 1181 queue.remove(b) 
1182 queue.appendleft(result) 1183 break 1184 else: 1185 operands.append(a) 1186 1187 if len(operands) < size: 1188 return functools.reduce( 1189 lambda a, b: expression.__class__(this=a, expression=b), operands 1190 ) 1191 return expression 1192 1193 1194def gen(expression: t.Any) -> str: 1195 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1196 1197 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1198 generator is expensive so we have a bare minimum sql generator here. 1199 """ 1200 return Gen().gen(expression) 1201 1202 1203class Gen: 1204 def __init__(self): 1205 self.stack = [] 1206 self.sqls = [] 1207 1208 def gen(self, expression: exp.Expression) -> str: 1209 self.stack = [expression] 1210 self.sqls.clear() 1211 1212 while self.stack: 1213 node = self.stack.pop() 1214 1215 if isinstance(node, exp.Expression): 1216 exp_handler_name = f"{node.key}_sql" 1217 1218 if hasattr(self, exp_handler_name): 1219 getattr(self, exp_handler_name)(node) 1220 elif isinstance(node, exp.Func): 1221 self._function(node) 1222 else: 1223 key = node.key.upper() 1224 self.stack.append(f"{key} " if self._args(node) else key) 1225 elif type(node) is list: 1226 for n in reversed(node): 1227 if n is not None: 1228 self.stack.extend((n, ",")) 1229 if node: 1230 self.stack.pop() 1231 else: 1232 if node is not None: 1233 self.sqls.append(str(node)) 1234 1235 return "".join(self.sqls) 1236 1237 def add_sql(self, e: exp.Add) -> None: 1238 self._binary(e, " + ") 1239 1240 def alias_sql(self, e: exp.Alias) -> None: 1241 self.stack.extend( 1242 ( 1243 e.args.get("alias"), 1244 " AS ", 1245 e.args.get("this"), 1246 ) 1247 ) 1248 1249 def and_sql(self, e: exp.And) -> None: 1250 self._binary(e, " AND ") 1251 1252 def anonymous_sql(self, e: exp.Anonymous) -> None: 1253 this = e.this 1254 if isinstance(this, str): 1255 name = this.upper() 1256 elif isinstance(this, exp.Identifier): 1257 name = this.this 1258 name = f'"{name}"' if 
this.quoted else name.upper() 1259 else: 1260 raise ValueError( 1261 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1262 ) 1263 1264 self.stack.extend( 1265 ( 1266 ")", 1267 e.expressions, 1268 "(", 1269 name, 1270 ) 1271 ) 1272 1273 def between_sql(self, e: exp.Between) -> None: 1274 self.stack.extend( 1275 ( 1276 e.args.get("high"), 1277 " AND ", 1278 e.args.get("low"), 1279 " BETWEEN ", 1280 e.this, 1281 ) 1282 ) 1283 1284 def boolean_sql(self, e: exp.Boolean) -> None: 1285 self.stack.append("TRUE" if e.this else "FALSE") 1286 1287 def bracket_sql(self, e: exp.Bracket) -> None: 1288 self.stack.extend( 1289 ( 1290 "]", 1291 e.expressions, 1292 "[", 1293 e.this, 1294 ) 1295 ) 1296 1297 def column_sql(self, e: exp.Column) -> None: 1298 for p in reversed(e.parts): 1299 self.stack.extend((p, ".")) 1300 self.stack.pop() 1301 1302 def datatype_sql(self, e: exp.DataType) -> None: 1303 self._args(e, 1) 1304 self.stack.append(f"{e.this.name} ") 1305 1306 def div_sql(self, e: exp.Div) -> None: 1307 self._binary(e, " / ") 1308 1309 def dot_sql(self, e: exp.Dot) -> None: 1310 self._binary(e, ".") 1311 1312 def eq_sql(self, e: exp.EQ) -> None: 1313 self._binary(e, " = ") 1314 1315 def from_sql(self, e: exp.From) -> None: 1316 self.stack.extend((e.this, "FROM ")) 1317 1318 def gt_sql(self, e: exp.GT) -> None: 1319 self._binary(e, " > ") 1320 1321 def gte_sql(self, e: exp.GTE) -> None: 1322 self._binary(e, " >= ") 1323 1324 def identifier_sql(self, e: exp.Identifier) -> None: 1325 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1326 1327 def ilike_sql(self, e: exp.ILike) -> None: 1328 self._binary(e, " ILIKE ") 1329 1330 def in_sql(self, e: exp.In) -> None: 1331 self.stack.append(")") 1332 self._args(e, 1) 1333 self.stack.extend( 1334 ( 1335 "(", 1336 " IN ", 1337 e.this, 1338 ) 1339 ) 1340 1341 def intdiv_sql(self, e: exp.IntDiv) -> None: 1342 self._binary(e, " DIV ") 1343 1344 def is_sql(self, e: exp.Is) -> None: 1345 
self._binary(e, " IS ") 1346 1347 def like_sql(self, e: exp.Like) -> None: 1348 self._binary(e, " Like ") 1349 1350 def literal_sql(self, e: exp.Literal) -> None: 1351 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1352 1353 def lt_sql(self, e: exp.LT) -> None: 1354 self._binary(e, " < ") 1355 1356 def lte_sql(self, e: exp.LTE) -> None: 1357 self._binary(e, " <= ") 1358 1359 def mod_sql(self, e: exp.Mod) -> None: 1360 self._binary(e, " % ") 1361 1362 def mul_sql(self, e: exp.Mul) -> None: 1363 self._binary(e, " * ") 1364 1365 def neg_sql(self, e: exp.Neg) -> None: 1366 self._unary(e, "-") 1367 1368 def neq_sql(self, e: exp.NEQ) -> None: 1369 self._binary(e, " <> ") 1370 1371 def not_sql(self, e: exp.Not) -> None: 1372 self._unary(e, "NOT ") 1373 1374 def null_sql(self, e: exp.Null) -> None: 1375 self.stack.append("NULL") 1376 1377 def or_sql(self, e: exp.Or) -> None: 1378 self._binary(e, " OR ") 1379 1380 def paren_sql(self, e: exp.Paren) -> None: 1381 self.stack.extend( 1382 ( 1383 ")", 1384 e.this, 1385 "(", 1386 ) 1387 ) 1388 1389 def sub_sql(self, e: exp.Sub) -> None: 1390 self._binary(e, " - ") 1391 1392 def subquery_sql(self, e: exp.Subquery) -> None: 1393 self._args(e, 2) 1394 alias = e.args.get("alias") 1395 if alias: 1396 self.stack.append(alias) 1397 self.stack.extend((")", e.this, "(")) 1398 1399 def table_sql(self, e: exp.Table) -> None: 1400 self._args(e, 4) 1401 alias = e.args.get("alias") 1402 if alias: 1403 self.stack.append(alias) 1404 for p in reversed(e.parts): 1405 self.stack.extend((p, ".")) 1406 self.stack.pop() 1407 1408 def tablealias_sql(self, e: exp.TableAlias) -> None: 1409 columns = e.columns 1410 1411 if columns: 1412 self.stack.extend((")", columns, "(")) 1413 1414 self.stack.extend((e.this, " AS ")) 1415 1416 def var_sql(self, e: exp.Var) -> None: 1417 self.stack.append(e.this) 1418 1419 def _binary(self, e: exp.Binary, op: str) -> None: 1420 self.stack.extend((e.expression, op, e.this)) 1421 1422 def _unary(self, e: 
exp.Unary, op: str) -> None: 1423 self.stack.extend((e.this, op)) 1424 1425 def _function(self, e: exp.Func) -> None: 1426 self.stack.extend( 1427 ( 1428 ")", 1429 list(e.args.values()), 1430 "(", 1431 e.sql_name(), 1432 ) 1433 ) 1434 1435 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1436 kvs = [] 1437 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1438 1439 for k in arg_types or arg_types: 1440 v = node.args.get(k) 1441 1442 if v is not None: 1443 kvs.append([f":{k}", v]) 1444 if kvs: 1445 self.stack.append(kvs) 1446 return True 1447 return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    The tree is rewritten in place: each node goes through a fixed sequence of pre-order
    rules, then its children are simplified, then a sequence of post-order rules runs.
    The whole procedure is driven by `while_changing`, and `remove_where_true` is applied
    once at the end.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: the dialect to use; it is resolved with `Dialect.get_or_raise` and
            forwarded to `simplify_datetrunc`

    Returns:
        sqlglot.Expression: simplified expression
    """

    # Resolve whatever the caller passed into a Dialect before any rule needs it
    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL must be left exactly as they are
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse: children are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have produced a brand new node; point it at the original
        # parent because the remaining rules (e.g. simplify_parens) inspect `.parent`
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the root splices itself into the tree; for children the caller
        # (replace_children) uses the returned node
        if root:
            expression.replace(node)
        return node

    # NOTE(review): while_changing presumably re-applies _simplify until the expression
    # stops changing -- confirm against sqlglot.helper
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function is called as usual. If it raises one of the listed exception
    types, the first positional argument (the expression being simplified) is returned
    unchanged instead. Any other exception propagates to the caller.

    Args:
        *exceptions: the exception classes that should be swallowed.

    Returns:
        A decorator that wraps a function taking the expression as its first argument.
    """

    def decorator(func):
        # Without functools.wraps the decorated function would be exposed as "wrapped",
        # losing its own name and docstring in tracebacks and generated documentation.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                # The rule doesn't apply; hand back the input untouched
                return expression

        return wrapped

    return decorator
Decorator that skips a simplification function, returning the expression unchanged, if any of `exceptions` is raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only handles lt/lte/gt/gte, so BETWEEN is spelled out as a
    pair of inequalities. Any expression that is not a BETWEEN is returned untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Look at the parent before building anything: under a NOT the new conjunction
    # has to be wrapped in parentheses
    under_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(conjunction, copy=False) if under_not else conjunction
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws are applied to parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Besides that, NOT NULL becomes NULL, a negated comparison listed in
    COMPLEMENT_COMPARISONS is replaced by its complement, NOT over a constant becomes the
    opposite boolean and a double negation is removed. Anything else is returned as is.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        complement_type = COMPLEMENT_COMPARISONS[operand_type]
        return complement_type(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        # Pick the dual connector builder, if the parenthesized node is a connector
        if isinstance(inner, exp.And):
            dual = exp.or_
        elif isinstance(inner, exp.Or):
            dual = exp.and_
        else:
            dual = None

        if dual is not None:
            return exp.paren(
                dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    # double negation
    # NOT NOT x -> x
    if isinstance(operand, exp.Not):
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Pull nested connectors of the same kind up into their parent connector.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Only drop the wrapping when what's inside is the same kind of connector
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR connectors whose operands are constants or duplicates.

    Non-connector expressions are returned untouched. For connectors, the pairwise rule
    below is applied across all operands through `_flat_simplify`.
    """

    def _simplify_connectors(expression, left, right):
        # x AND x -> x, x OR x -> x
        if left == right:
            return left

        left_true, right_true = always_true(left), always_true(right)
        left_false, right_false = is_false(left), is_false(right)
        left_null, right_null = is_null(left), is_null(right)

        if isinstance(expression, exp.And):
            if left_false or right_false:
                return exp.false()
            if left_null or right_null:
                return exp.null()
            if left_true:
                return exp.true() if right_true else right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if left_true or right_true:
                return exp.true()
            if left_false and right_false:
                return exp.false()
            # NULL OR NULL, NULL OR FALSE and FALSE OR NULL are all NULL
            if (left_null and (right_null or right_false)) or (left_false and right_null):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression

    # Nested connectors of the same kind are handled when their topmost parent is visited
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)

    if any(is_complement(first, second) for first, second in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate the operands of a connector and put them in sorted order.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector) or not (root or not expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Keyed by the generated pseudo sql, so equal operands collapse into a single entry
    by_sql = {gen(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # check if the operands are already sorted, if not sort them
    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))

    if out_of_order:
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    # we didn't have to sort but maybe we need to dedup
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place with `replace`; the input expression is returned.
    Some rules only substitute a boolean constant or leave a duplicate operand behind,
    which other simplification rules are expected to fold afterwards.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The laws pair an operand with a nested connector of the *opposite* kind
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair of operands; `a` is the candidate nested connector
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for the neutral element of a's connector
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    # same as above, for the other operand of a
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a proper subset of a's operands, so a is redundant:
                    # swap it for the neutral element of the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one operand and disagree only on the sign of the other
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Columns are replaced in place; the input expression is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node inside the defining equality, literal)
        constant_mapping = {}
        # NOTE(review): the prune callback flags exp.If nodes, presumably so that
        # equalities inside conditional branches are not collected -- confirm
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # don't touch the occurrence that defines the constant (the `b` in `b = 5`)
                    and id(column) != column_id
                    # leave `column IS NULL` checks alone
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Run the decorated simplification; if it raises one of the expected exceptions,
    # give the expression back unchanged
    try:
        outcome = func(expression, *args, **kwargs)
    except exceptions:
        return expression
    else:
        return outcome
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
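`l` is the left operand of `=` (here `x + 1`) and `r` is its right operand (here `3`); `a` and `b` are the operands of `l` (here `x` and `1`):

      l       r
    x + 1  =  3
    a   b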
def simplify_literals(expression, root=True):
    """
    Fold operations on literals.

    Binary (non-connector) expressions are folded pairwise with `_simplify_binary`, a
    negated number literal gets its sign flipped, and expressions whose type is in
    INVERSE_DATE_OPS are folded against their interval. Anything else is returned as is.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # -(-5) -> 5 and -(5) -> -5
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """
    Remove a pair of parentheses when one of the conditions below holds.

    Returns the wrapped expression when the parentheses are dropped, otherwise the
    original node. Non-Paren nodes are returned untouched.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # A parenthesized SELECT always keeps its parentheses
    if not isinstance(this, exp.Select) and (
        # the parent is neither a condition nor a binary operation
        not isinstance(parent, (exp.Condition, exp.Binary))
        # ((x)) -> (x)
        or isinstance(parent, exp.Paren)
        # the content is not a binary operation, unless it is a NOT / IS inside a predicate
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        # a predicate whose parent is not itself a predicate
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        # a + (b + c) and a * (b * c)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # a multiplication nested in an addition or subtraction
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons that have a COALESCE on one side.

    A COALESCE with a single argument, or whose first argument is a non-null constant,
    is replaced by that argument. A comparison between a COALESCE and a constant is
    rewritten into an explicit `IS NULL` / `IS NOT NULL` case split when the COALESCE
    has a constant among its fallback arguments. The COALESCE node is mutated in place.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Find which side of the comparison holds the COALESCE
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant and everything after it; `arg` still refers to that constant
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            # the remaining arguments are not null: keep the (now shortened) comparison
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            # they are null: the constant fallback is compared instead
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Runs of adjacent string literals are merged into a single literal (joined with the
    separator for CONCAT_WS). If everything collapses into one string literal, that
    literal is returned; otherwise a new concat node of the matching type is built.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        concat_type = exp.Concat
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)

    for is_string_run, run in runs:
        if is_string_run:
            joined = sep.join(literal.name for literal in run)
            merged.append(exp.Literal.string(joined))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        # The separator stays the first argument
        merged = [sep_expr, *merged]

    return concat_type(expressions=merged, **extra_args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, a branch whose condition is always true is returned directly and branches
    whose condition is always false are removed; once no branch is left the default (or
    NULL) is returned. A standalone IF resolves to its true or false side.
    """
    if isinstance(expression, exp.Case):
        subject = expression.this

        for branch in expression.args["ifs"]:
            predicate = branch.this

            if subject:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                predicate = predicate.replace(subject.pop().eq(predicate))

            if always_true(predicate):
                return branch.args["true"]

            if not always_false(predicate):
                continue

            branch.pop()
            if not expression.args["ifs"]:
                return expression.args.get("default") or exp.null()

        return expression

    # IF nodes that are the branches of a CASE are handled above
    if isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        condition = expression.this

        if always_true(condition):
            return expression.args["true"]

        if always_false(condition):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Evaluate a prefix check when both the string and the prefix are string literals,
    turning it into TRUE or FALSE. Other expressions are returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if not (expression.this.is_string and expression.expression.is_string):
        return expression

    text = expression.name
    prefix = expression.expression.name
    return exp.convert(text.startswith(prefix))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
def wrapped(expression, *args, **kwargs):
    # Run the decorated simplification; if it raises one of the expected exceptions,
    # give the expression back unchanged
    try:
        outcome = func(expression, *args, **kwargs)
    except exceptions:
        return expression
    else:
        return outcome
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison in a canonical order.

    A lone column stays on (or is moved to) the left and a lone constant on the right.
    When neither rule decides, operands are ordered by their generated pseudo sql.
    Swapping operands also swaps the operator through INVERSE_COMPARISONS.
    """
    comparison_type = expression.__class__

    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    # Already in canonical order
    if (left_is_column and not right_is_column) or (right_is_const and not left_is_const):
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if not needs_swap:
        return expression

    swapped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
    return swapped_type(this=right, expression=left)
def remove_where_true(expression):
    """
    Remove filters that are always true, in place.

    WHERE clauses with an always-true condition are dropped. Joins whose ON condition is
    always true are turned into CROSS joins, provided they have no USING clause, no join
    method, and their (side, kind) combination is listed in JOINS.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Returns a boolean literal expression for EQ / IS, NEQ, GT, GTE, LT and LTE, and None
    for any other kind of expression.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None

    return boolean_literal(outcome)
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, if possible.

    Args:
        value: a `datetime.datetime`, a `datetime.date` or an ISO formatted string.

    Returns:
        The corresponding date (the time part of datetimes and datetime strings is
        dropped), or None if `value` cannot be interpreted as a date.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: not an ISO formatted string; TypeError: not a string at all
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, if possible.

    Args:
        value: a `datetime.datetime`, a `datetime.date` or an ISO formatted string.

    Returns:
        The corresponding datetime (dates are placed at midnight), or None if `value`
        cannot be interpreted as a datetime.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: not an ISO formatted string; TypeError: not a string at all
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python object matching the temporal type `to`.

    Args:
        value: the value to convert, e.g. an ISO formatted string or a date.
        to: the target data type.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` when `to` is one of
        the other temporal types, and None when `value` is falsy, `to` is not a temporal
        type or the conversion fails.
    """
    if not value:
        return None
    # DATE is tested first so that it yields a date rather than a datetime
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a cast of a literal to a temporal type into a Python date or datetime.

    Handles `CAST(<literal> AS <type>)` and `TsOrDsToDate` without a format, including
    casts nested inside each other.

    Args:
        cast: the expression to evaluate.

    Returns:
        The date or datetime the expression evaluates to, or None when it is not such a
        cast or the literal cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # A TsOrDsToDate without a format is treated like a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first, then convert its result
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """
    Build the `relativedelta` that represents `n` times the given unit.

    Supported units are year, quarter, month, week, day, hour, minute and second;
    anything else raises UnsupportedUnit. dateutil is imported lazily, inside the function.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, how many of that keyword make up one unit)
    unit_table = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in unit_table:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`, the first day of the week expressed as an offset
            from Monday (0 is Monday, -1 is Sunday).

    Returns:
        The first day of the enclosing year, quarter, month or week, or `d` itself for
        the "day" unit.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # d.weekday() is 0 for Monday through 6 for Sunday. Wrapping modulo 7 keeps the
        # distance to the week start within 0..6, so a date that already is the first day
        # of the week floors to itself and the result is never later than `d`.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Generate a cheap pseudo sql string for `expression`, usable for sorting and deduping.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive, so a bare minimum sql generator (`Gen`) is used instead.
    """
    generator = Gen()
    return generator.gen(expression)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
class Gen:
    """Bare-minimum, non-recursive pseudo-SQL generator.

    `gen` drives an explicit work stack (`self.stack`): items are popped one at a
    time, expression nodes push their pieces back in *reverse* order (so they pop
    off left to right), and every other non-None item is stringified and appended
    to `self.sqls`. The output is only meant to be a sortable / comparable key,
    not valid SQL.

    Handlers are looked up by name as `<node.key>_sql`; nodes without a handler
    fall back to `_function` (for `exp.Func`) or to a generic `KEY :arg,value`
    rendering built by `_args`.
    """

    def __init__(self):
        # Work stack of pending items (expressions, lists, strings, None).
        self.stack: t.List[t.Any] = []
        # Output fragments, joined at the end of `gen`.
        self.sqls: t.List[str] = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo-SQL string for `expression`.

        Resets `self.stack` and `self.sqls`, then processes the stack until it is
        empty.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                handler = getattr(self, f"{node.key}_sql", None)

                if handler is not None:
                    handler(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # `_args` must run first so the key is pushed on top of the args
                    # and therefore emitted before them.
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                self._list(node)
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Push `NAME(expressions)`; unquoted names are upper-cased, quoted ones kept verbatim.

        Raises:
            ValueError: if `e.this` is neither a `str` nor an `exp.Identifier`.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." pushed after the first part (it would lead the output).
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Skip the first arg type; it is rendered through `e.this.name` instead.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        # Everything after the first arg type goes inside the parentheses.
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): the mixed-case " Like " differs from the other operators;
        # left untouched so generated keys do not change.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # The first two arg types are rendered explicitly below.
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # The first four arg types are rendered explicitly below (parts and alias).
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the "." pushed after the first part (it would lead the output).
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this op expression` (in reverse, so `this` is emitted first)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `op this` (in reverse, so `op` is emitted first)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `SQL_NAME(arg values...)` for a function without a dedicated handler."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _list(self, items: t.List[t.Any]) -> None:
        """Push the non-None `items` separated by "," so they are emitted in order.

        Each kept item is pushed followed by a ","; the last "," pushed sits on top
        of the stack and would be emitted first, so it is removed. It is removed
        only if at least one item was pushed: for a list made up solely of None
        values nothing of ours is on the stack, and popping would discard an
        unrelated pending item (e.g. a closing parenthesis).
        """
        stack = self.stack
        pushed = False

        for item in reversed(items):
            if item is not None:
                stack.extend((item, ","))
                pushed = True

        if pushed:
            stack.pop()

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push `[":name", value]` pairs for the set args of `node`.

        Args:
            node: expression whose `arg_types` / `args` are inspected.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one pair was pushed onto the stack, False otherwise.
        """
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        kvs = []
        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])

        if kvs:
            self.stack.append(kvs)
            return True
        return False
def gen(self, expression: exp.Expression) -> str:
    """Return the pseudo-SQL string for `expression`.

    Resets `self.stack` and `self.sqls`, then pops items off the stack until it
    is empty:

    * expressions are dispatched to a `<key>_sql` handler if one exists, else to
      `_function` for `exp.Func`, else rendered as `KEY` followed by their args;
    * lists push their non-None elements separated by ",";
    * any other non-None item is stringified and appended to `self.sqls`.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            handler = getattr(self, f"{node.key}_sql", None)

            if handler is not None:
                handler(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                # `_args` must run first so the key is pushed on top of the args
                # and therefore emitted before them.
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False

            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True

            # Remove the "," on top of the stack (it would be emitted first), but
            # only if this list pushed something: a list holding only None values
            # must not pop an unrelated pending item.
            if pushed:
                self.stack.pop()
        elif node is not None:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Push the pieces of an anonymous function call, `NAME(expressions)`.

    A plain-string name and an unquoted identifier are upper-cased; a quoted
    identifier keeps its text and is wrapped in double quotes.

    Raises:
        ValueError: if `e.this` is neither a `str` nor an `exp.Identifier`.
    """
    this = e.this
    is_text = isinstance(this, str)

    # Guard clause: reject anything that is not a str or an Identifier.
    if not is_text and not isinstance(this, exp.Identifier):
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
        )

    if is_text:
        rendered = this.upper()
    elif this.quoted:
        rendered = f'"{this.this}"'
    else:
        rendered = this.this.upper()

    # Pushed in reverse so the name is emitted first and ")" last.
    self.stack.extend([")", e.expressions, "(", rendered])