sqlglot.optimizer.simplify
import datetime
import functools
import itertools
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes flagged FINAL above are returned untouched.
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        # Children are simplified as non-root nodes.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        # simplify_parens (below) inspects the parent, so re-attach it first.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)
        node = simplify_coalesce(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    Simplify NOT expressions, including De Morgan's laws.

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    NOT NULL -> NULL
    NOT <always true> -> FALSE
    NOT FALSE -> TRUE
    NOT NOT x -> x
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            # Only unwrap when the unnested child is the same connector type.
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """
    Simplify AND / OR connectors whose operands are constants or comparable predicates.

    Non-connector expressions are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Returns a replacement for the operand pair, or None if nothing applies.
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that share a column into a single expression.

    Args:
        expression: the connector holding both comparisons.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the connector is an OR, False for AND.
    Returns:
        A replacement expression, or None when no simplification applies.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                # One side has no operand besides the shared columns.
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return compliment
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed (and therefore deduplicated and ordered) by `generate(operand)`.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place; the (mutated) input expression is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of nested operands we inspect.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def simplify_literals(expression, root=True):
    """
    Fold constant operands of non-connector binary expressions and negated numbers.

    Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)
    elif isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            # Negating an already-negative literal strips its sign.
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """
    Evaluate a binary expression over two operands when both are constant.

    Args:
        expression: the binary expression being simplified.
        a: left operand.
        b: right operand.
    Returns:
        The folded expression, or None when the pair cannot be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Decimal division by zero raises; leave `x / 0` in the query untouched
            # instead of crashing the optimizer.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Drop a Paren wrapper when one of the conditions below marks it as removable."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(this, exp.Predicate)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return expression.this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    COALESCE(x) -> x, and a comparison between a COALESCE and a constant is expanded
    into an OR of two ANDs split on whether the leading COALESCE arguments are NULL.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Truncate in place: `expression.copy()` below must see the shortened COALESCE.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.or_(
        exp.and_(
            coalesce.is_(exp.null()).not_(copy=False),
            expression.copy(),
            copy=False,
        ),
        exp.and_(
            coalesce.is_(exp.null()),
            type(expression)(this=arg.copy(), expression=other.copy()),
            copy=False,
        ),
        copy=False,
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    new_args = []
    for is_string_group, group in itertools.groupby(
        expression.expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string("".join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """
    Remove always-true WHERE clauses and turn eligible `JOIN ... ON <true>` into CROSS joins.

    The expression is modified in place; nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value for a Boolean whose `this` is truthy, or for any Literal."""
    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
        expression, exp.Literal
    )


def is_complement(a, b):
    """Return True when `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Boolean node with a falsy `this`."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Null node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` over the Python values `a` and `b`.

    Returns:
        A boolean literal expression, or None if `expression` is not a supported comparison.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def extract_date(cast):
    """
    Parse the operand of a CAST to DATE / DATETIME into a Python date or datetime.

    Returns:
        The parsed value, or None if the target type is neither or parsing fails.
    """
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        if cast.args["to"].this == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if cast.args["to"].this == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None


def extract_interval(interval):
    """
    Convert an interval expression into a `dateutil` relativedelta.

    Returns:
        A relativedelta for year / month / week / day units, or None when dateutil
        is not installed, the quantity is not an integer, or the unit is unsupported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    # The quantity may not be an integer literal (e.g. '1.5' or a column name);
    # in that case the interval cannot be folded.
    try:
        n = int(interval.name)
    except ValueError:
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None


def date_literal(date):
    """Build a CAST of `date` (as a string literal) to DATETIME or DATE, based on its type."""
    return exp.cast(
        exp.Literal.string(date),
        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
    )


def boolean_literal(condition):
    """Return a TRUE expression if `condition` is truthy, else a FALSE expression."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """
    Combine the flattened operands of `expression` pairwise using `simplifier`.

    Each operand is tried against the operands after it; when `simplifier` returns a
    truthy result, that result replaces the pair and is retried from the front.

    Returns:
        A rebuilt left-deep chain of `expression`'s class if any operands were merged,
        otherwise the original expression.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
def simplify(expression):
    """
    Simplify a sqlglot AST by applying a fixed sequence of rewrite rules.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    def _mentions_group_key(tree, keys):
        # True as soon as any node walked in `tree` is one of the GROUP BY keys.
        return any(node in keys for node, *_ in tree.walk())

    # GROUP BY expressions cannot be simplified, for example
    #   select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # because the projection must exactly match the group by key.
    for group in expression.find_all(exp.Group):
        select = group.parent
        group_keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            if _mentions_group_key(projection, group_keys):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and _mentions_group_key(having, group_keys):
            having.meta[FINAL] = True

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        for rule in (
            rewrite_between,
            lambda n: uniq_sort(n, generate, root),
            lambda n: absorb_and_eliminate(n, root),
            simplify_concat,
        ):
            node = rule(node)

        # Children are simplified as non-root nodes.
        exp.replace_children(node, lambda child: _simplify(child, False))

        # Post-order transformations (first half)
        for rule in (
            simplify_not,
            flatten,
            lambda n: simplify_connectors(n, root),
            lambda n: remove_compliments(n, root),
        ):
            node = rule(node)

        # Re-attach the parent before the remaining rules run (simplify_parens reads it).
        node.parent = expression.parent

        # Post-order transformations (second half)
        for rule in (
            lambda n: simplify_literals(n, root),
            simplify_parens,
            simplify_coalesce,
        ):
            node = rule(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Turn `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded into that form first. Other expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # The subject is copied for each side; the bounds are reused as-is.
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify NOT expressions, including De Morgan's laws.

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    NOT NULL -> NULL
    NOT <always true> -> FALSE
    NOT FALSE -> TRUE
    NOT NOT x -> x
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()
        if isinstance(condition, (exp.And, exp.Or)):
            # De Morgan: negate both sides and switch the connector.
            dual = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return dual(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )
        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Unwrap nested connectors of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Only replace when the unnested operand is the same connector type.
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Simplify AND / OR connectors whose operands are constants or comparable predicates.

    Non-connector expressions are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Returns a replacement for the operand pair, or None if nothing applies.
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true = always_true(left)
            right_true = always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false, right_false = is_false(left), is_false(right)
            if left_false and right_false:
                return exp.false()

            left_null, right_null = is_null(left), is_null(right)
            if (left_null and (right_null or right_false)) or (left_false and right_null):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removes complements.
A AND NOT A -> FALSE; A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by `generate(operand)`; operands with the same key
    collapse into one, and the keys determine the sort order.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {generate(operand): operand for operand in operands}
    sql_keys = tuple(by_sql)

    # A AND C AND B -> A AND B AND C: rebuild only if some key precedes its predecessor.
    if any(current < previous for previous, current in zip(sql_keys, sql_keys[1:])):
        return combine(*(by_sql[sql] for sql in sorted(sql_keys)), copy=False)

    # Already in order, but duplicates may still have been dropped.
    if len(by_sql) < len(operands):
        return combine(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    # `kind` is the opposite connector: the type of nested operands we inspect.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    inner_is_and = kind == exp.And

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb (a fresh constant node is built for every replacement)
        if is_complement(b, aa):
            aa.replace(exp.true() if inner_is_and else exp.false())
        elif is_complement(b, ab):
            ab.replace(exp.true() if inner_is_and else exp.false())
        elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
            a.replace(exp.false() if inner_is_and else exp.true())
        elif isinstance(b, kind):
            # eliminate
            rhs = b.unnest_operands()
            ba, bb = rhs

            for shared, other in ((aa, ab), (ab, aa)):
                if shared in rhs and (is_complement(other, ba) or is_complement(other, bb)):
                    a.replace(shared)
                    b.replace(shared)
                    break

    return expression
Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.
def simplify_literals(expression, root=True):
    """
    Fold constant operands of non-connector binary expressions and negated numbers.

    Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating an already-negative literal strips its sign.
            negated = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(negated)

    return expression
def simplify_parens(expression):
    """Drop a Paren wrapper when one of the conditions below marks it as removable."""
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # Parenthesized SELECTs are always kept.
    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if removable else expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    COALESCE(x) -> x, and a comparison between a COALESCE and a constant is expanded
    into an OR of two ANDs split on whether the leading COALESCE arguments are NULL.
    """
    # COALESCE(x) -> x
    # (COALESCE is also used as a Spark partitioning hint, which must be left alone)
    is_single_arg_coalesce = (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        and not isinstance(expression.parent, exp.Hint)
    )
    if is_single_arg_coalesce:
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    fallbacks = coalesce.expressions
    arg_index = next(
        (index for index, candidate in enumerate(fallbacks) if isinstance(candidate, CONSTANTS)),
        None,
    )
    if arg_index is None:
        return expression
    arg = fallbacks[arg_index]

    # Truncate in place: `expression.copy()` below must see the shortened COALESCE.
    coalesce.set("expressions", fallbacks[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    if not coalesce.expressions:
        coalesce = coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    not_null_case = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_case = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.or_(not_null_case, null_case, copy=False)
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal."""
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for is_literal_run, members in itertools.groupby(operands, lambda e: e.is_string):
        if is_literal_run:
            merged.append(exp.Literal.string("".join(member.name for member in members)))
        else:
            merged.extend(members)

    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_type(expressions=merged)
Reduces all groups that contain string literals by concatenating them.
def remove_where_true(expression):
    """
    Remove always-true WHERE clauses and turn eligible `JOIN ... ON <true>` into CROSS joins.

    The expression is modified in place; nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        # Only the (side, kind) pairs listed in JOINS may become CROSS joins.
        if (join.side, join.kind) not in JOINS:
            continue

        for key, value in (("on", None), ("side", None), ("kind", "CROSS")):
            join.set(key, value)
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` over the Python values `a` and `b`.

    Returns:
        A boolean literal expression, or None if `expression` is not a supported comparison.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None

    return boolean_literal(outcome)
def extract_date(cast):
    """
    Parse the operand of a CAST to DATE / DATETIME into a Python date or datetime.

    Returns:
        The parsed value, or None if the target type is neither or parsing fails.
    """
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        target_type = cast.args["to"].this
        if target_type == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target_type == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
def extract_interval(interval):
    """
    Convert an interval expression into a `dateutil` relativedelta.

    Args:
        interval: object exposing `name` (the quantity) and `text("unit")`.
    Returns:
        A relativedelta for year / month / week / day units, or None when dateutil
        is not installed, the quantity is not an integer, or the unit is unsupported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    # The quantity may not be an integer literal (e.g. '1.5' or a column name).
    # Treat that as "cannot extract" instead of letting ValueError escape.
    try:
        n = int(interval.name)
    except ValueError:
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None