sqlglot.optimizer.simplify
import datetime
import functools
import itertools
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify

    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes flagged FINAL above are returned untouched.
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        # Recurse into the children; they are never treated as the root.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # simplify_parens inspects node.parent, so restore it before the last passes.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    Simplify negations, including De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT <always true> -> FALSE,
    NOT FALSE -> TRUE and NOT NOT x -> x.
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """Fold constant operands (TRUE / FALSE / NULL) and redundant comparisons in AND / OR chains."""

    def _simplify_connectors(expression, left, right):
        # Returns the merged operand, or None when the pair cannot be combined.
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Operator to use when the two sides of an inequality are swapped.
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons that share a column and compare it against literals.

    Returns the merged expression, `expression` itself when both sides consist
    only of shared columns, or None when nothing can be merged.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_compliments(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return compliment
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        # Operands are keyed by their generated text, so duplicates collapse.
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector of interest is the opposite of the outer one.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def simplify_literals(expression, root=True):
    """Fold binary operations on literals and negations of number literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)
    elif isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """Evaluate `a <op> b` for constant operands.

    Returns the folded expression, or None when the pair cannot be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Folding a division by a zero literal would raise inside the
            # optimizer; leave the expression as is so the engine decides.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Return the wrapped expression when the parentheses around it are redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(this, exp.Predicate)
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """Drop single-argument COALESCE calls and expand comparisons of COALESCE against a constant."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    new_args = []
    for is_string_group, group in itertools.groupby(
        expression.expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string("".join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop WHERE clauses that are always true and turn eligible `JOIN ... ON <true>` into CROSS joins."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value for a Boolean whose value is truthy, or for any Literal."""
    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
        expression, exp.Literal
    )


def is_complement(a, b):
    """Return True when `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Boolean node with a falsy value."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Null node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on Python values `a` and `b`.

    Returns a boolean literal expression, or None for unsupported operators.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def extract_date(cast):
    """Return the date / datetime value of a DATE or DATETIME cast, or None if it cannot be parsed."""
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        if cast.args["to"].this == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if cast.args["to"].this == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None


def extract_interval(interval):
    """Convert an interval expression into a `relativedelta`.

    Returns None when the interval value is not an integer, when `dateutil`
    is not installed, or when the unit is not year / month / week / day.
    """
    # The value may be a non-integer string (e.g. '1.5' or an identifier name);
    # such intervals cannot be folded, so bail out instead of raising.
    try:
        n = int(interval.name)
    except ValueError:
        return None

    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None


def date_literal(date):
    """Build a cast of the stringified `date` to DATETIME (for datetimes) or DATE."""
    return exp.cast(
        exp.Literal.string(date),
        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
    )


def boolean_literal(condition):
    """Return a TRUE or FALSE expression according to the truthiness of `condition`."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Combine the flattened operands of `expression` pairwise with `simplifier`.

    `simplifier(expression, a, b)` returns a merged operand or a falsy value.
    If at least one pair was merged, the remaining operands are folded back
    into a left-nested chain of `expression`'s class; otherwise `expression`
    is returned unchanged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result:
                    # Replace the pair (a, b) with the merged result and retry.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify

    Returns:
        sqlglot.Expression: simplified expression
    """
    generate = cached_generator()

    def _mentions_group_key(tree, keys):
        # True as soon as any node of `tree` equals one of the GROUP BY keys.
        return any(node in keys for node, *_ in tree.walk())

    # GROUP BY keys must stay textually identical to the projections that use
    # them (select x + 1 + 1 FROM y GROUP BY x + 1 + 1), so the group itself and
    # every projection / HAVING clause touching a key is frozen.
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            if _mentions_group_key(projection, groups):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and _mentions_group_key(having, groups):
            having.meta[FINAL] = True

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Passes that run before the children are visited.
        node = rewrite_between(expression)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        exp.replace_children(node, lambda child: _simplify(child, False))

        # Passes that run after the children were simplified.
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    rewritten into that form first. Any other expression is returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a negation.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Constant folding:
        NOT NULL -> NULL, NOT <always true> -> FALSE, NOT FALSE -> TRUE

    Double negation:
        NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        # Negating a connector swaps it for its dual and negates both sides.
        if isinstance(condition, exp.And):
            dual = exp.or_
        elif isinstance(condition, exp.Or):
            dual = exp.and_
        else:
            dual = None

        if dual is not None:
            return dual(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )
        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        return operand.this
    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Unwrap nested connectors of the same type:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold constant operands and redundant comparisons inside AND / OR chains.

    Non-connector expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    def _merge_pair(connector, left, right):
        # Returns the operand that replaces the pair, or None if it cannot be merged.
        if left == right:
            return left

        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true = always_true(left)
            right_true = always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false = is_false(left)
            right_false = is_false(right)
            if left_false and right_false:
                return exp.false()

            left_null = is_null(left)
            right_null = is_null(right)
            if (left_null and (right_null or right_false)) or (left_false and right_null):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        return None

    return _flat_simplify(expression, _merge_pair, root)
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    has_complementary_pair = any(
        is_complement(first_operand, second_operand)
        for first_operand, second_operand in itertools.permutations(expression.flatten(), 2)
    )
    if not has_complementary_pair:
        return expression
    return exp.false() if isinstance(expression, exp.And) else exp.true()
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed and ordered by the text `generate` produces for them.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {generate(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # Out of order somewhere: A AND C AND B -> A AND B AND C
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    # Already ordered, but duplicates may still have to go.
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression
Deduplicate and sort the operands of a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten through `replace`, and the `expression` that was
    passed in is what gets returned. The replacements are neutral constants
    (TRUE / FALSE) or duplicated operands, not the final simplified form.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # Look for nested connectors of the opposite type (OR inside AND, AND inside OR).
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair of flattened operands is considered.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for the neutral element of `kind`.
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a proper subset of a's, so a is redundant:
                    # swap it for the neutral element of the outer connector.
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one operand and disagree only on the sign of
                    # the other, so both collapse to the shared operand.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binary operators and negated number literals.

    Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating a negative literal drops the sign; otherwise prepend one.
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    return expression
def simplify_parens(expression):
    """Strip a Paren node when its parentheses are redundant.

    Returns the wrapped expression when the parentheses can go, otherwise the
    original `expression`. Parenthesized selects are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls.

    Two rewrites are applied:
      * a COALESCE with no extra arguments (outside of a Hint) is replaced by
        its single argument;
      * a comparison between a COALESCE and a constant, where the COALESCE has
        a constant among its extra arguments, is expanded into
        `(<c> IS NOT NULL AND <comparison>) OR (<c> IS NULL AND <constant arg> <op> <other>)`,
        with `<c>` being the COALESCE truncated before its first constant.

    Any other expression is returned unchanged. Note that the COALESCE node is
    truncated in place before the comparison is copied.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE side of the comparison; `other` is the opposite side.
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Drop the constant and everything after it (mutates the node in place).
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                # copied after the truncation above, so it holds the shortened COALESCE
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                # same comparison operator, applied to the constant fallback
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal.

    ConcatWs and non-concat expressions are returned unchanged. If everything
    collapses into a single argument, that argument is returned on its own.
    """
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if all_strings:
            merged.append(exp.Literal.string("".join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]

    # Keep the "safe" flavour of the concat if that is what we started with.
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_type(expressions=merged)
Reduces all groups that contain string literals by concatenating them.
def remove_where_true(expression):
    """Remove always-true WHERE clauses and turn eligible `JOIN ... ON <true>` into CROSS joins.

    The tree is modified in place; nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args

        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        # Only the (side, kind) pairs listed in JOINS may become CROSS joins.
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison operator of `expression` on the Python values `a` and `b`.

    Returns the matching boolean literal expression, or None when `expression`
    is not one of EQ / Is / NEQ / GT / GTE / LT / LTE.
    """
    # Checked in order; the comparison itself is only evaluated for the match.
    comparators = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for operator_types, compare in comparators:
        if isinstance(expression, operator_types):
            return boolean_literal(compare())
    return None
def extract_date(cast):
    """Parse the value of a DATE or DATETIME cast into a Python date / datetime.

    Returns None when the cast targets another type or when the value is not
    in ISO format (for instance when the cast wraps an identifier).
    """
    try:
        target_type = cast.args["to"].this
        if target_type == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target_type == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
def extract_interval(interval):
    """Convert an interval expression into a `dateutil` `relativedelta`.

    Args:
        interval: object exposing `name` (the interval value as a string) and
            `text("unit")` (the interval unit).

    Returns:
        A `relativedelta` for year / month / week / day units, or None when the
        value is not an integer, `dateutil` is not installed, or the unit is
        not supported.
    """
    # The value may be a non-integer string (e.g. '1.5' or an identifier name).
    # Such an interval cannot be folded, so bail out instead of raising.
    try:
        n = int(interval.name)
    except ValueError:
        return None

    # dateutil is optional; without it intervals are simply not simplified.
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None