sqlglot.optimizer.simplify
1import datetime 2import functools 3import itertools 4import typing as t 5from collections import deque 6from decimal import Decimal 7 8import sqlglot 9from sqlglot import exp 10from sqlglot.generator import cached_generator 11from sqlglot.helper import first, merge_ranges, while_changing 12from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 13 14# Final means that an expression should not be simplified 15FINAL = "final" 16 17 18class UnsupportedUnit(Exception): 19 pass 20 21 22def simplify(expression, constant_propagation=False): 23 """ 24 Rewrite sqlglot AST to simplify expressions. 25 26 Example: 27 >>> import sqlglot 28 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 29 >>> simplify(expression).sql() 30 'TRUE' 31 32 Args: 33 expression (sqlglot.Expression): expression to simplify 34 constant_propagation: whether or not the constant propagation rule should be used 35 36 Returns: 37 sqlglot.Expression: simplified expression 38 """ 39 40 generate = cached_generator() 41 42 # group by expressions cannot be simplified, for example 43 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 44 # the projection must exactly match the group by key 45 for group in expression.find_all(exp.Group): 46 select = group.parent 47 groups = set(group.expressions) 48 group.meta[FINAL] = True 49 50 for e in select.selects: 51 for node, *_ in e.walk(): 52 if node in groups: 53 e.meta[FINAL] = True 54 break 55 56 having = select.args.get("having") 57 if having: 58 for node, *_ in having.walk(): 59 if node in groups: 60 having.meta[FINAL] = True 61 break 62 63 def _simplify(expression, root=True): 64 if expression.meta.get(FINAL): 65 return expression 66 67 # Pre-order transformations 68 node = expression 69 node = rewrite_between(node) 70 node = uniq_sort(node, generate, root) 71 node = absorb_and_eliminate(node, root) 72 node = simplify_concat(node) 73 74 if constant_propagation: 75 node = propagate_constants(node, root) 76 77 exp.replace_children(node, lambda e: 
_simplify(e, False)) 78 79 # Post-order transformations 80 node = simplify_not(node) 81 node = flatten(node) 82 node = simplify_connectors(node, root) 83 node = remove_complements(node, root) 84 node = simplify_coalesce(node) 85 node.parent = expression.parent 86 node = simplify_literals(node, root) 87 node = simplify_equality(node) 88 node = simplify_parens(node) 89 node = simplify_datetrunc_predicate(node) 90 91 if root: 92 expression.replace(node) 93 94 return node 95 96 expression = while_changing(expression, _simplify) 97 remove_where_true(expression) 98 return expression 99 100 101def catch(*exceptions): 102 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 103 104 def decorator(func): 105 def wrapped(expression, *args, **kwargs): 106 try: 107 return func(expression, *args, **kwargs) 108 except exceptions: 109 return expression 110 111 return wrapped 112 113 return decorator 114 115 116def rewrite_between(expression: exp.Expression) -> exp.Expression: 117 """Rewrite x between y and z to x >= y AND x <= z. 118 119 This is done because comparison simplification is only done on lt/lte/gt/gte. 
120 """ 121 if isinstance(expression, exp.Between): 122 return exp.and_( 123 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 124 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 125 copy=False, 126 ) 127 return expression 128 129 130def simplify_not(expression): 131 """ 132 Demorgan's Law 133 NOT (x OR y) -> NOT x AND NOT y 134 NOT (x AND y) -> NOT x OR NOT y 135 """ 136 if isinstance(expression, exp.Not): 137 if is_null(expression.this): 138 return exp.null() 139 if isinstance(expression.this, exp.Paren): 140 condition = expression.this.unnest() 141 if isinstance(condition, exp.And): 142 return exp.or_( 143 exp.not_(condition.left, copy=False), 144 exp.not_(condition.right, copy=False), 145 copy=False, 146 ) 147 if isinstance(condition, exp.Or): 148 return exp.and_( 149 exp.not_(condition.left, copy=False), 150 exp.not_(condition.right, copy=False), 151 copy=False, 152 ) 153 if is_null(condition): 154 return exp.null() 155 if always_true(expression.this): 156 return exp.false() 157 if is_false(expression.this): 158 return exp.true() 159 if isinstance(expression.this, exp.Not): 160 # double negation 161 # NOT NOT x -> x 162 return expression.this.this 163 return expression 164 165 166def flatten(expression): 167 """ 168 A AND (B AND C) -> A AND B AND C 169 A OR (B OR C) -> A OR B OR C 170 """ 171 if isinstance(expression, exp.Connector): 172 for node in expression.args.values(): 173 child = node.unnest() 174 if isinstance(child, expression.__class__): 175 node.replace(child) 176 return expression 177 178 179def simplify_connectors(expression, root=True): 180 def _simplify_connectors(expression, left, right): 181 if left == right: 182 return left 183 if isinstance(expression, exp.And): 184 if is_false(left) or is_false(right): 185 return exp.false() 186 if is_null(left) or is_null(right): 187 return exp.null() 188 if always_true(left) and always_true(right): 189 return exp.true() 190 if always_true(left): 191 return 
right 192 if always_true(right): 193 return left 194 return _simplify_comparison(expression, left, right) 195 elif isinstance(expression, exp.Or): 196 if always_true(left) or always_true(right): 197 return exp.true() 198 if is_false(left) and is_false(right): 199 return exp.false() 200 if ( 201 (is_null(left) and is_null(right)) 202 or (is_null(left) and is_false(right)) 203 or (is_false(left) and is_null(right)) 204 ): 205 return exp.null() 206 if is_false(left): 207 return right 208 if is_false(right): 209 return left 210 return _simplify_comparison(expression, left, right, or_=True) 211 212 if isinstance(expression, exp.Connector): 213 return _flat_simplify(expression, _simplify_connectors, root) 214 return expression 215 216 217LT_LTE = (exp.LT, exp.LTE) 218GT_GTE = (exp.GT, exp.GTE) 219 220COMPARISONS = ( 221 *LT_LTE, 222 *GT_GTE, 223 exp.EQ, 224 exp.NEQ, 225 exp.Is, 226) 227 228INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 229 exp.LT: exp.GT, 230 exp.GT: exp.LT, 231 exp.LTE: exp.GTE, 232 exp.GTE: exp.LTE, 233} 234 235 236def _simplify_comparison(expression, left, right, or_=False): 237 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 238 ll, lr = left.args.values() 239 rl, rr = right.args.values() 240 241 largs = {ll, lr} 242 rargs = {rl, rr} 243 244 matching = largs & rargs 245 columns = {m for m in matching if isinstance(m, exp.Column)} 246 247 if matching and columns: 248 try: 249 l = first(largs - columns) 250 r = first(rargs - columns) 251 except StopIteration: 252 return expression 253 254 # make sure the comparison is always of the form x > 1 instead of 1 < x 255 if left.__class__ in INVERSE_COMPARISONS and l == ll: 256 left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll) 257 if right.__class__ in INVERSE_COMPARISONS and r == rl: 258 right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl) 259 260 if l.is_number and r.is_number: 261 l = float(l.name) 262 r = float(r.name) 
263 elif l.is_string and r.is_string: 264 l = l.name 265 r = r.name 266 else: 267 return None 268 269 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 270 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 271 return left if (av > bv if or_ else av <= bv) else right 272 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 273 return left if (av < bv if or_ else av >= bv) else right 274 275 # we can't ever shortcut to true because the column could be null 276 if not or_: 277 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 278 if av <= bv: 279 return exp.false() 280 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 281 if av >= bv: 282 return exp.false() 283 elif isinstance(a, exp.EQ): 284 if isinstance(b, exp.LT): 285 return exp.false() if av >= bv else a 286 if isinstance(b, exp.LTE): 287 return exp.false() if av > bv else a 288 if isinstance(b, exp.GT): 289 return exp.false() if av <= bv else a 290 if isinstance(b, exp.GTE): 291 return exp.false() if av < bv else a 292 if isinstance(b, exp.NEQ): 293 return exp.false() if av == bv else a 294 return None 295 296 297def remove_complements(expression, root=True): 298 """ 299 Removing complements. 300 301 A AND NOT A -> FALSE 302 A OR NOT A -> TRUE 303 """ 304 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 305 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 306 307 for a, b in itertools.permutations(expression.flatten(), 2): 308 if is_complement(a, b): 309 return complement 310 return expression 311 312 313def uniq_sort(expression, generate, root=True): 314 """ 315 Uniq and sort a connector. 
316 317 C AND A AND B AND B -> A AND B AND C 318 """ 319 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 320 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 321 flattened = tuple(expression.flatten()) 322 deduped = {generate(e): e for e in flattened} 323 arr = tuple(deduped.items()) 324 325 # check if the operands are already sorted, if not sort them 326 # A AND C AND B -> A AND B AND C 327 for i, (sql, e) in enumerate(arr[1:]): 328 if sql < arr[i][0]: 329 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 330 break 331 else: 332 # we didn't have to sort but maybe we need to dedup 333 if len(deduped) < len(flattened): 334 expression = result_func(*deduped.values(), copy=False) 335 336 return expression 337 338 339def absorb_and_eliminate(expression, root=True): 340 """ 341 absorption: 342 A AND (A OR B) -> A 343 A OR (A AND B) -> A 344 A AND (NOT A OR B) -> A AND B 345 A OR (NOT A AND B) -> A OR B 346 elimination: 347 (A AND B) OR (A AND NOT B) -> A 348 (A OR B) AND (A OR NOT B) -> A 349 """ 350 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 351 kind = exp.Or if isinstance(expression, exp.And) else exp.And 352 353 for a, b in itertools.permutations(expression.flatten(), 2): 354 if isinstance(a, kind): 355 aa, ab = a.unnest_operands() 356 357 # absorb 358 if is_complement(b, aa): 359 aa.replace(exp.true() if kind == exp.And else exp.false()) 360 elif is_complement(b, ab): 361 ab.replace(exp.true() if kind == exp.And else exp.false()) 362 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 363 a.replace(exp.false() if kind == exp.And else exp.true()) 364 elif isinstance(b, kind): 365 # eliminate 366 rhs = b.unnest_operands() 367 ba, bb = rhs 368 369 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 370 a.replace(aa) 371 b.replace(aa) 372 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 373 
a.replace(ab) 374 b.replace(ab) 375 376 return expression 377 378 379def propagate_constants(expression, root=True): 380 """ 381 Propagate constants for conjunctions in DNF: 382 383 SELECT * FROM t WHERE a = b AND b = 5 becomes 384 SELECT * FROM t WHERE a = 5 AND b = 5 385 386 Reference: https://www.sqlite.org/optoverview.html 387 """ 388 389 if ( 390 isinstance(expression, exp.And) 391 and (root or not expression.same_parent) 392 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 393 ): 394 constant_mapping = {} 395 for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)): 396 if isinstance(expr, exp.EQ): 397 l, r = expr.left, expr.right 398 399 # TODO: create a helper that can be used to detect nested literal expressions such 400 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 401 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 402 pass 403 elif isinstance(r, exp.Column) and isinstance(l, exp.Literal): 404 l, r = r, l 405 else: 406 continue 407 408 constant_mapping[l] = (id(l), r) 409 410 if constant_mapping: 411 for column in find_all_in_scope(expression, exp.Column): 412 parent = column.parent 413 column_id, constant = constant_mapping.get(column) or (None, None) 414 if ( 415 column_id is not None 416 and id(column) != column_id 417 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 418 ): 419 column.replace(constant.copy()) 420 421 return expression 422 423 424INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 425 exp.DateAdd: exp.Sub, 426 exp.DateSub: exp.Add, 427 exp.DatetimeAdd: exp.Sub, 428 exp.DatetimeSub: exp.Add, 429} 430 431INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 432 **INVERSE_DATE_OPS, 433 exp.Add: exp.Sub, 434 exp.Sub: exp.Add, 435} 436 437 438def _is_number(expression: exp.Expression) -> bool: 439 return expression.is_number 440 441 442def _is_interval(expression: 
exp.Expression) -> bool: 443 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 444 445 446@catch(ModuleNotFoundError, UnsupportedUnit) 447def simplify_equality(expression: exp.Expression) -> exp.Expression: 448 """ 449 Use the subtraction and addition properties of equality to simplify expressions: 450 451 x + 1 = 3 becomes x = 2 452 453 There are two binary operations in the above expression: + and = 454 Here's how we reference all the operands in the code below: 455 456 l r 457 x + 1 = 3 458 a b 459 """ 460 if isinstance(expression, COMPARISONS): 461 l, r = expression.left, expression.right 462 463 if l.__class__ in INVERSE_OPS: 464 pass 465 elif r.__class__ in INVERSE_OPS: 466 l, r = r, l 467 else: 468 return expression 469 470 if r.is_number: 471 a_predicate = _is_number 472 b_predicate = _is_number 473 elif _is_date_literal(r): 474 a_predicate = _is_date_literal 475 b_predicate = _is_interval 476 else: 477 return expression 478 479 if l.__class__ in INVERSE_DATE_OPS: 480 a = l.this 481 b = l.interval() 482 else: 483 a, b = l.left, l.right 484 485 if not a_predicate(a) and b_predicate(b): 486 pass 487 elif not a_predicate(b) and b_predicate(a): 488 a, b = b, a 489 else: 490 return expression 491 492 return expression.__class__( 493 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 494 ) 495 return expression 496 497 498def simplify_literals(expression, root=True): 499 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 500 return _flat_simplify(expression, _simplify_binary, root) 501 502 if isinstance(expression, exp.Neg): 503 this = expression.this 504 if this.is_number: 505 value = this.name 506 if value[0] == "-": 507 return exp.Literal.number(value[1:]) 508 return exp.Literal.number(f"-{value}") 509 510 return expression 511 512 513def _simplify_binary(expression, a, b): 514 if isinstance(expression, exp.Is): 515 if isinstance(b, exp.Not): 516 c = b.this 517 not_ = True 
518 else: 519 c = b 520 not_ = False 521 522 if is_null(c): 523 if isinstance(a, exp.Literal): 524 return exp.true() if not_ else exp.false() 525 if is_null(a): 526 return exp.false() if not_ else exp.true() 527 elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): 528 return None 529 elif is_null(a) or is_null(b): 530 return exp.null() 531 532 if a.is_number and b.is_number: 533 a = int(a.name) if a.is_int else Decimal(a.name) 534 b = int(b.name) if b.is_int else Decimal(b.name) 535 536 if isinstance(expression, exp.Add): 537 return exp.Literal.number(a + b) 538 if isinstance(expression, exp.Sub): 539 return exp.Literal.number(a - b) 540 if isinstance(expression, exp.Mul): 541 return exp.Literal.number(a * b) 542 if isinstance(expression, exp.Div): 543 # engines have differing int div behavior so intdiv is not safe 544 if isinstance(a, int) and isinstance(b, int): 545 return None 546 return exp.Literal.number(a / b) 547 548 boolean = eval_boolean(expression, a, b) 549 550 if boolean: 551 return boolean 552 elif a.is_string and b.is_string: 553 boolean = eval_boolean(expression, a.this, b.this) 554 555 if boolean: 556 return boolean 557 elif _is_date_literal(a) and isinstance(b, exp.Interval): 558 a, b = extract_date(a), extract_interval(b) 559 if a and b: 560 if isinstance(expression, exp.Add): 561 return date_literal(a + b) 562 if isinstance(expression, exp.Sub): 563 return date_literal(a - b) 564 elif isinstance(a, exp.Interval) and _is_date_literal(b): 565 a, b = extract_interval(a), extract_date(b) 566 # you cannot subtract a date from an interval 567 if a and b and isinstance(expression, exp.Add): 568 return date_literal(a + b) 569 570 return None 571 572 573def simplify_parens(expression): 574 if not isinstance(expression, exp.Paren): 575 return expression 576 577 this = expression.this 578 parent = expression.parent 579 580 if not isinstance(this, exp.Select) and ( 581 not isinstance(parent, (exp.Condition, exp.Binary)) 582 or isinstance(parent, 
exp.Paren) 583 or not isinstance(this, exp.Binary) 584 or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate)) 585 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 586 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 587 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 588 ): 589 return this 590 return expression 591 592 593CONSTANTS = ( 594 exp.Literal, 595 exp.Boolean, 596 exp.Null, 597) 598 599 600def simplify_coalesce(expression): 601 # COALESCE(x) -> x 602 if ( 603 isinstance(expression, exp.Coalesce) 604 and not expression.expressions 605 # COALESCE is also used as a Spark partitioning hint 606 and not isinstance(expression.parent, exp.Hint) 607 ): 608 return expression.this 609 610 if not isinstance(expression, COMPARISONS): 611 return expression 612 613 if isinstance(expression.left, exp.Coalesce): 614 coalesce = expression.left 615 other = expression.right 616 elif isinstance(expression.right, exp.Coalesce): 617 coalesce = expression.right 618 other = expression.left 619 else: 620 return expression 621 622 # This transformation is valid for non-constants, 623 # but it really only does anything if they are both constants. 624 if not isinstance(other, CONSTANTS): 625 return expression 626 627 # Find the first constant arg 628 for arg_index, arg in enumerate(coalesce.expressions): 629 if isinstance(arg, CONSTANTS): 630 break 631 else: 632 return expression 633 634 coalesce.set("expressions", coalesce.expressions[:arg_index]) 635 636 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 637 # since we already remove COALESCE at the top of this function. 
638 coalesce = coalesce if coalesce.expressions else coalesce.this 639 640 # This expression is more complex than when we started, but it will get simplified further 641 return exp.paren( 642 exp.or_( 643 exp.and_( 644 coalesce.is_(exp.null()).not_(copy=False), 645 expression.copy(), 646 copy=False, 647 ), 648 exp.and_( 649 coalesce.is_(exp.null()), 650 type(expression)(this=arg.copy(), expression=other.copy()), 651 copy=False, 652 ), 653 copy=False, 654 ) 655 ) 656 657 658CONCATS = (exp.Concat, exp.DPipe) 659SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe) 660 661 662def simplify_concat(expression): 663 """Reduces all groups that contain string literals by concatenating them.""" 664 if not isinstance(expression, CONCATS) or ( 665 # We can't reduce a CONCAT_WS call if we don't statically know the separator 666 isinstance(expression, exp.ConcatWs) 667 and not expression.expressions[0].is_string 668 ): 669 return expression 670 671 if isinstance(expression, exp.ConcatWs): 672 sep_expr, *expressions = expression.expressions 673 sep = sep_expr.name 674 concat_type = exp.ConcatWs 675 else: 676 expressions = expression.expressions 677 sep = "" 678 concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat 679 680 new_args = [] 681 for is_string_group, group in itertools.groupby( 682 expressions or expression.flatten(), lambda e: e.is_string 683 ): 684 if is_string_group: 685 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 686 else: 687 new_args.extend(group) 688 689 if len(new_args) == 1 and new_args[0].is_string: 690 return new_args[0] 691 692 if concat_type is exp.ConcatWs: 693 new_args = [sep_expr] + new_args 694 695 return concat_type(expressions=new_args) 696 697 698DateRange = t.Tuple[datetime.date, datetime.date] 699 700 701def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]: 702 """ 703 Get the date range for a DATE_TRUNC equality comparison: 704 705 Example: 706 
_datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 707 Returns: 708 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 709 """ 710 floor = date_floor(date, unit) 711 712 if date != floor: 713 # This will always be False, except for NULL values. 714 return None 715 716 return floor, floor + interval(unit) 717 718 719def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression: 720 """Get the logical expression for a date range""" 721 return exp.and_( 722 left >= date_literal(drange[0]), 723 left < date_literal(drange[1]), 724 copy=False, 725 ) 726 727 728def _datetrunc_eq( 729 left: exp.Expression, date: datetime.date, unit: str 730) -> t.Optional[exp.Expression]: 731 drange = _datetrunc_range(date, unit) 732 if not drange: 733 return None 734 735 return _datetrunc_eq_expression(left, drange) 736 737 738def _datetrunc_neq( 739 left: exp.Expression, date: datetime.date, unit: str 740) -> t.Optional[exp.Expression]: 741 drange = _datetrunc_range(date, unit) 742 if not drange: 743 return None 744 745 return exp.and_( 746 left < date_literal(drange[0]), 747 left >= date_literal(drange[1]), 748 copy=False, 749 ) 750 751 752DateTruncBinaryTransform = t.Callable[ 753 [exp.Expression, datetime.date, str], t.Optional[exp.Expression] 754] 755DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 756 exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)), 757 exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)), 758 exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)), 759 exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)), 760 exp.EQ: _datetrunc_eq, 761 exp.NEQ: _datetrunc_neq, 762} 763DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 764 765 766def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 767 return isinstance(left, (exp.DateTrunc, 
exp.TimestampTrunc)) and _is_date_literal(right) 768 769 770@catch(ModuleNotFoundError, UnsupportedUnit) 771def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: 772 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 773 comparison = expression.__class__ 774 775 if comparison not in DATETRUNC_COMPARISONS: 776 return expression 777 778 if isinstance(expression, exp.Binary): 779 l, r = expression.left, expression.right 780 781 if _is_datetrunc_predicate(l, r): 782 pass 783 elif _is_datetrunc_predicate(r, l): 784 comparison = INVERSE_COMPARISONS.get(comparison, comparison) 785 l, r = r, l 786 else: 787 return expression 788 789 unit = l.unit.name.lower() 790 date = extract_date(r) 791 792 if not date: 793 return expression 794 795 return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression 796 elif isinstance(expression, exp.In): 797 l = expression.this 798 rs = expression.expressions 799 800 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 801 unit = l.unit.name.lower() 802 803 ranges = [] 804 for r in rs: 805 date = extract_date(r) 806 if not date: 807 return expression 808 drange = _datetrunc_range(date, unit) 809 if drange: 810 ranges.append(drange) 811 812 if not ranges: 813 return expression 814 815 ranges = merge_ranges(ranges) 816 817 return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False) 818 819 return expression 820 821 822# CROSS joins result in an empty table if the right table is empty. 823# So we can only simplify certain types of joins to CROSS. 
824# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 825JOINS = { 826 ("", ""), 827 ("", "INNER"), 828 ("RIGHT", ""), 829 ("RIGHT", "OUTER"), 830} 831 832 833def remove_where_true(expression): 834 for where in expression.find_all(exp.Where): 835 if always_true(where.this): 836 where.parent.set("where", None) 837 for join in expression.find_all(exp.Join): 838 if ( 839 always_true(join.args.get("on")) 840 and not join.args.get("using") 841 and not join.args.get("method") 842 and (join.side, join.kind) in JOINS 843 ): 844 join.set("on", None) 845 join.set("side", None) 846 join.set("kind", "CROSS") 847 848 849def always_true(expression): 850 return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( 851 expression, exp.Literal 852 ) 853 854 855def is_complement(a, b): 856 return isinstance(b, exp.Not) and b.this == a 857 858 859def is_false(a: exp.Expression) -> bool: 860 return type(a) is exp.Boolean and not a.this 861 862 863def is_null(a: exp.Expression) -> bool: 864 return type(a) is exp.Null 865 866 867def eval_boolean(expression, a, b): 868 if isinstance(expression, (exp.EQ, exp.Is)): 869 return boolean_literal(a == b) 870 if isinstance(expression, exp.NEQ): 871 return boolean_literal(a != b) 872 if isinstance(expression, exp.GT): 873 return boolean_literal(a > b) 874 if isinstance(expression, exp.GTE): 875 return boolean_literal(a >= b) 876 if isinstance(expression, exp.LT): 877 return boolean_literal(a < b) 878 if isinstance(expression, exp.LTE): 879 return boolean_literal(a <= b) 880 return None 881 882 883def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 884 if isinstance(value, datetime.datetime): 885 return value.date() 886 if isinstance(value, datetime.date): 887 return value 888 try: 889 return datetime.datetime.fromisoformat(value).date() 890 except ValueError: 891 return None 892 893 894def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 895 if isinstance(value, datetime.datetime): 896 return 
value 897 if isinstance(value, datetime.date): 898 return datetime.datetime(year=value.year, month=value.month, day=value.day) 899 try: 900 return datetime.datetime.fromisoformat(value) 901 except ValueError: 902 return None 903 904 905def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 906 if not value: 907 return None 908 if to.is_type(exp.DataType.Type.DATE): 909 return cast_as_date(value) 910 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 911 return cast_as_datetime(value) 912 return None 913 914 915def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 916 if isinstance(cast, exp.Cast): 917 to = cast.to 918 elif isinstance(cast, exp.TsOrDsToDate): 919 to = exp.DataType.build(exp.DataType.Type.DATE) 920 else: 921 return None 922 923 if isinstance(cast.this, exp.Literal): 924 value: t.Any = cast.this.name 925 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 926 value = extract_date(cast.this) 927 else: 928 return None 929 return cast_value(value, to) 930 931 932def _is_date_literal(expression: exp.Expression) -> bool: 933 return extract_date(expression) is not None 934 935 936def extract_interval(expression): 937 n = int(expression.name) 938 unit = expression.text("unit").lower() 939 940 try: 941 return interval(unit, n) 942 except (UnsupportedUnit, ModuleNotFoundError): 943 return None 944 945 946def date_literal(date): 947 return exp.cast( 948 exp.Literal.string(date), 949 exp.DataType.Type.DATETIME 950 if isinstance(date, datetime.datetime) 951 else exp.DataType.Type.DATE, 952 ) 953 954 955def interval(unit: str, n: int = 1): 956 from dateutil.relativedelta import relativedelta 957 958 if unit == "year": 959 return relativedelta(years=1 * n) 960 if unit == "quarter": 961 return relativedelta(months=3 * n) 962 if unit == "month": 963 return relativedelta(months=1 * n) 964 if unit == "week": 965 return relativedelta(weeks=1 * n) 966 if unit == "day": 967 return 
relativedelta(days=1 * n) 968 if unit == "hour": 969 return relativedelta(hours=1 * n) 970 if unit == "minute": 971 return relativedelta(minutes=1 * n) 972 if unit == "second": 973 return relativedelta(seconds=1 * n) 974 975 raise UnsupportedUnit(f"Unsupported unit: {unit}") 976 977 978def date_floor(d: datetime.date, unit: str) -> datetime.date: 979 if unit == "year": 980 return d.replace(month=1, day=1) 981 if unit == "quarter": 982 if d.month <= 3: 983 return d.replace(month=1, day=1) 984 elif d.month <= 6: 985 return d.replace(month=4, day=1) 986 elif d.month <= 9: 987 return d.replace(month=7, day=1) 988 else: 989 return d.replace(month=10, day=1) 990 if unit == "month": 991 return d.replace(month=d.month, day=1) 992 if unit == "week": 993 # Assuming week starts on Monday (0) and ends on Sunday (6) 994 return d - datetime.timedelta(days=d.weekday()) 995 if unit == "day": 996 return d 997 998 raise UnsupportedUnit(f"Unsupported unit: {unit}") 999 1000 1001def date_ceil(d: datetime.date, unit: str) -> datetime.date: 1002 floor = date_floor(d, unit) 1003 1004 if floor == d: 1005 return d 1006 1007 return floor + interval(unit) 1008 1009 1010def boolean_literal(condition): 1011 return exp.true() if condition else exp.false() 1012 1013 1014def _flat_simplify(expression, simplifier, root=True): 1015 if root or not expression.same_parent: 1016 operands = [] 1017 queue = deque(expression.flatten(unnest=False)) 1018 size = len(queue) 1019 1020 while queue: 1021 a = queue.popleft() 1022 1023 for b in queue: 1024 result = simplifier(expression, a, b) 1025 1026 if result and result is not expression: 1027 queue.remove(b) 1028 queue.appendleft(result) 1029 break 1030 else: 1031 operands.append(a) 1032 1033 if len(operands) < size: 1034 return functools.reduce( 1035 lambda a, b: expression.__class__(this=a, expression=b), operands 1036 ) 1037 return expression
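Raised by `interval` and `date_floor` when they are given a unit they do not handle. The class defines no docstring of its own, so the generated text is inherited from `Exception`: "Common base class for all non-exit exceptions."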
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    def _mentions_group_key(tree, keys):
        # True as soon as any node yielded by walking `tree` is one of the GROUP BY keys.
        return any(candidate in keys for candidate, *_ in tree.walk())

    # Anything that must textually match a GROUP BY key is frozen up front, e.g.
    #   SELECT x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # where rewriting only the projection would make it diverge from the key.
    for group in expression.find_all(exp.Group):
        keys = set(group.expressions)
        select = group.parent
        group.meta[FINAL] = True

        for projection in select.selects:
            if _mentions_group_key(projection, keys):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and _mentions_group_key(having, keys):
            having.meta[FINAL] = True

    def _simplify(expression, root=True):
        # Nodes flagged FINAL above are returned untouched.
        if expression.meta.get(FINAL):
            return expression

        # Rules applied to the node before its children are visited.
        pre_order = (
            rewrite_between,
            lambda n: uniq_sort(n, generate, root),
            lambda n: absorb_and_eliminate(n, root),
            simplify_concat,
        )
        # Rules applied once the children have been simplified.
        post_order_before_reparent = (
            simplify_not,
            flatten,
            lambda n: simplify_connectors(n, root),
            lambda n: remove_complements(n, root),
            simplify_coalesce,
        )
        # Rules applied after the node has been given the original node's parent.
        post_order_after_reparent = (
            lambda n: simplify_literals(n, root),
            simplify_equality,
            simplify_parens,
            simplify_datetrunc_predicate,
        )

        node = expression
        for rule in pre_order:
            node = rule(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda e: _simplify(e, False))

        for rule in post_order_before_reparent:
            node = rule(node)

        node.parent = expression.parent

        for rule in post_order_after_reparent:
            node = rule(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether or not the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception types that should be swallowed.

    Returns:
        A decorator. The decorated function returns its first positional
        argument (the expression) unchanged whenever the wrapped function
        raises one of `exceptions`; any other exception propagates.
    """

    def decorator(func):
        # Preserve the wrapped function's name, docstring and signature metadata
        # so that introspection and generated docs show the real function.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of exceptions
are raised
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    rewritten into that form first. Any other expression is returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Each bound gets its own copy of the subject; the bounds themselves are reused.
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify negations.

    De Morgan's laws are applied to parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    NOT NULL becomes NULL, negated boolean constants are folded and
    double negation (NOT NOT x) is removed.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # The negation of an AND is an OR of negations, and vice versa.
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Drop parentheses around operands that use the same connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold boolean constants, NULLs and duplicate operands inside AND / OR connectors.

    Non-connector expressions are returned unchanged.
    """

    def _reduce_operands(expression, left, right):
        # Identical operands collapse into one, whatever the connector is.
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left):
                # TRUE AND TRUE -> TRUE, TRUE AND x -> x
                return exp.true() if always_true(right) else right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            # NULL OR NULL, NULL OR FALSE and FALSE OR NULL are all NULL
            if (is_null(left) and (is_null(right) or is_false(right))) or (
                is_false(left) and is_null(right)
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression

    return _flat_simplify(expression, _reduce_operands, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its complement.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression

    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)

    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Remove duplicate operands from a connector and order them by generated SQL.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Keyed by generated SQL, so operands that render identically are merged.
    by_sql = {generate(operand): operand for operand in operands}
    pairs = tuple(by_sql.items())

    already_sorted = all(
        previous[0] <= current[0] for previous, current in zip(pairs, pairs[1:])
    )

    if not already_sorted:
        # A AND C AND B -> A AND B AND C
        return build(*(operand for _, operand in sorted(pairs)), copy=False)

    if len(by_sql) < len(operands):
        # we didn't have to sort but we still need to dedup
        return build(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, rewriting its
    operands in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced by TRUE / FALSE or by a shared sub-operand; the
    given expression object itself is always the return value.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the one nested operands must use
        # for any of the laws above to apply.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair of operands is considered, so each operand gets a
        # turn at being the nested connector `a`.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # The complemented operand becomes the identity element of
                    # the nested connector (TRUE for AND, FALSE for OR).
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's, so a is redundant:
                    # it becomes the identity element of the outer connector.
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # Both sides share one operand and disagree only on a
                    # complemented one: keep just the shared operand.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Columns are replaced in place; the given expression is always returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node in the defining equality, literal).
        constant_mapping = {}
        # Sub-trees under an IF node are not searched for equalities.
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    # Normalize so that `l` is the column and `r` the literal.
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # Leave the defining `column = literal` occurrence itself alone.
                    and id(column) != column_id
                    # Do not rewrite the column inside `column IS NULL`.
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Run the wrapped simplification function; if it raises one of the
    # exception types captured from the enclosing decorator, give up and
    # return the expression that was passed in.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
      l       r
    x + 1  =  3
    a   b
def simplify_literals(expression, root=True):
    """
    Fold literal operands of non-connector binary expressions and
    negated numeric literals.

    A negated number is turned into a single number literal with the sign
    flipped. Anything else is returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating a negative number strips the sign, otherwise one is added.
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    return expression
def simplify_parens(expression):
    """
    Remove parentheses that are not needed to preserve the meaning of the expression.

    Returns the wrapped expression when the parentheses are redundant,
    otherwise the given expression unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # Parenthesized SELECTs always keep their parentheses.
    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )

    return inner if redundant else expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    * COALESCE(x) -> x (unless it is used as a hint).
    * A comparison between a COALESCE and a constant, where the COALESCE has a
      constant argument, is split into the case where the leading arguments
      are not NULL and the case where they are:

        COALESCE(x, 1) = 2 -> (NOT x IS NULL AND x = 2) OR (x IS NULL AND 1 = 2)

    Operand order is preserved in the NULL branch, so comparisons that are not
    symmetric (<, <=, >, >=) stay correct when the COALESCE is the right operand.

    Note that the COALESCE node of a rewritten comparison is truncated in place.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Arguments after the first constant can never be reached, and the constant
    # itself is handled by the NULL branch built below.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # The constant takes the place of the COALESCE, so it must sit on the same
    # side of the comparison as the COALESCE did.
    if coalesce_on_left:
        null_case = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        null_case = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                null_case,
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge every run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS):
        return expression

    with_separator = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if with_separator and not expression.expressions[0].is_string:
        return expression

    if with_separator:
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat

    operands = expressions or expression.flatten()

    new_args = []
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if not all_strings:
            new_args.extend(run)
            continue

        merged = sep.join(literal.name for literal in run)
        new_args.append(exp.Literal.string(merged))

    # Everything folded into a single literal: the concat call is not needed.
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args)
Reduces all groups that contain string literals by concatenating them.
def wrapped(expression, *args, **kwargs):
    # Run the wrapped simplification function; if it raises one of the
    # exception types captured from the enclosing decorator, give up and
    # return the expression that was passed in.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def remove_where_true(expression):
    """
    Drop conditions that are always true, modifying the expression in place.

    WHERE clauses whose condition is always true are removed, and qualifying
    joins whose ON condition is always true are turned into CROSS joins.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        for arg_name, arg_value in (("on", None), ("side", None), ("kind", "CROSS")):
            join.set(arg_name, arg_value)
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison described by `expression` on the Python values `a` and `b`.

    Returns a boolean literal for =, IS, <>, >, >=, < and <=,
    or None when the expression is not one of those comparisons.
    """
    # The comparisons are wrapped in lambdas so only the matching one is evaluated.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Args:
        value: a datetime (returned as is), a date (converted to midnight of
            that day) or an ISO 8601 formatted string.

    Returns:
        The datetime, or None if `value` cannot be interpreted as one.
    """
    # datetime must be checked first because it is a subclass of date.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string at all.
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Cast `value` to the Python equivalent of the temporal data type `to`.

    Args:
        value: the value to cast.
        to: the target data type.

    Returns:
        The result of `cast_as_date` for DATE, the result of `cast_as_datetime`
        for any other temporal type, and None when `value` is falsy or `to` is
        not a temporal type.
    """
    if not value:
        return None
    # DATE is tested before the general temporal check so that it is handled by
    # the date conversion rather than the datetime one.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Evaluate a (possibly nested) cast of a literal to a temporal type.

    Args:
        cast: a CAST or TS_OR_DS_TO_DATE expression whose operand is a literal
            or another such cast.

    Returns:
        The value produced by `cast_value`, or None when `cast` does not have
        the expected shape.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        # TS_OR_DS_TO_DATE is treated as a cast to DATE.
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are evaluated from the inside out.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """
    Build a `relativedelta` spanning `n` units.

    Supported units are year, quarter, month, week, day, hour, minute and second.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of it make up one unit)
    unit_specs = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    spec = unit_specs.get(unit)
    if spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spec
    return relativedelta(**{keyword: multiplier * n})
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Round `d` down to the start of the `unit` that contains it.

    Supported units are year, quarter, month, week and day.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if unit == "day":
        return d
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        quarter_start = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=quarter_start, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")