Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4import typing as t
  5from collections import deque
  6from decimal import Decimal
  7
  8from sqlglot import exp
  9from sqlglot.generator import cached_generator
 10from sqlglot.helper import first, merge_ranges, while_changing
 11
 12# Final means that an expression should not be simplified
 13FINAL = "final"
 14
 15
 16class UnsupportedUnit(Exception):
 17    pass
 18
 19
 20def simplify(expression):
 21    """
 22    Rewrite sqlglot AST to simplify expressions.
 23
 24    Example:
 25        >>> import sqlglot
 26        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 27        >>> simplify(expression).sql()
 28        'TRUE'
 29
 30    Args:
 31        expression (sqlglot.Expression): expression to simplify
 32    Returns:
 33        sqlglot.Expression: simplified expression
 34    """
 35
 36    generate = cached_generator()
 37
 38    # group by expressions cannot be simplified, for example
 39    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 40    # the projection must exactly match the group by key
 41    for group in expression.find_all(exp.Group):
 42        select = group.parent
 43        groups = set(group.expressions)
 44        group.meta[FINAL] = True
 45
 46        for e in select.selects:
 47            for node, *_ in e.walk():
 48                if node in groups:
 49                    e.meta[FINAL] = True
 50                    break
 51
 52        having = select.args.get("having")
 53        if having:
 54            for node, *_ in having.walk():
 55                if node in groups:
 56                    having.meta[FINAL] = True
 57                    break
 58
 59    def _simplify(expression, root=True):
 60        if expression.meta.get(FINAL):
 61            return expression
 62
 63        # Pre-order transformations
 64        node = expression
 65        node = rewrite_between(node)
 66        node = uniq_sort(node, generate, root)
 67        node = absorb_and_eliminate(node, root)
 68        node = simplify_concat(node)
 69
 70        exp.replace_children(node, lambda e: _simplify(e, False))
 71
 72        # Post-order transformations
 73        node = simplify_not(node)
 74        node = flatten(node)
 75        node = simplify_connectors(node, root)
 76        node = remove_compliments(node, root)
 77        node = simplify_coalesce(node)
 78        node.parent = expression.parent
 79        node = simplify_literals(node, root)
 80        node = simplify_equality(node)
 81        node = simplify_parens(node)
 82        node = simplify_datetrunc_predicate(node)
 83
 84        if root:
 85            expression.replace(node)
 86
 87        return node
 88
 89    expression = while_changing(expression, _simplify)
 90    remove_where_true(expression)
 91    return expression
 92
 93
 94def catch(*exceptions):
 95    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 96
 97    def decorator(func):
 98        def wrapped(expression, *args, **kwargs):
 99            try:
100                return func(expression, *args, **kwargs)
101            except exceptions:
102                return expression
103
104        return wrapped
105
106    return decorator
107
108
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite `x BETWEEN y AND z` as `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded into that form first. Any other node is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
121
122
def simplify_not(expression):
    """
    Simplify NOT, applying De Morgan's laws to parenthesized connectors:

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT <always true> -> FALSE, NOT FALSE -> TRUE
    and NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression
157
158
def flatten(expression):
    """
    Lift nested connectors of the same kind into their parent (in place):

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
170
171
def simplify_connectors(expression, root=True):
    """
    Fold constant operands of AND / OR connectors and merge comparable comparisons.

    AND: x AND x -> x; FALSE wins; NULL AND NULL -> NULL; always-true operands drop out.
    OR: x OR x -> x; always-true wins; FALSE operands drop out;
        NULL OR NULL / NULL OR FALSE -> NULL.
    Remaining operand pairs are handed to _simplify_comparison.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NULL AND x is FALSE when x is FALSE, so a single NULL operand does
            # not decide the result; only NULL AND NULL folds to NULL.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
208
209
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Comparison node types understood by the comparison simplifications.
COMPARISONS = (*LT_LTE, *GT_GTE, exp.EQ, exp.NEQ, exp.Is)

# Comparison obtained by swapping the operands: a < b <=> b > a, a <= b <=> b >= a.
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **dict(zip(LT_LTE, GT_GTE)),
    **dict(zip(GT_GTE, LT_LTE)),
}
227
228
def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons of the same column against literal values.

    ``left`` and ``right`` are operands of the connector ``expression`` (AND,
    or OR when ``or_`` is set). When both compare one shared column with a
    number (or both with a string), the pair is reduced to the tighter bound
    for AND / the looser bound for OR, or to FALSE when an AND can never hold.
    With equal bounds of mixed strictness (``x < 1`` vs ``x <= 1``), the strict
    one is kept for AND and the non-strict one for OR.

    Returns the replacement, ``expression`` itself when the operands cannot be
    split into column/value, or None when nothing can be merged.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    if av == bv and isinstance(a, exp.LT) != isinstance(b, exp.LT):
                        # Same bound, mixed strictness: x < v is stricter than x <= v.
                        strict, loose = (a, b) if isinstance(a, exp.LT) else (b, a)
                        return loose if or_ else strict
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    if av == bv and isinstance(a, exp.GT) != isinstance(b, exp.GT):
                        # Same bound, mixed strictness: x > v is stricter than x >= v.
                        strict, loose = (a, b) if isinstance(a, exp.GT) else (b, a)
                        return loose if or_ else strict
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
288
289
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    (The function name keeps its historical spelling; "complement" is meant.)
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operands = list(expression.flatten())
    if any(is_complement(x, y) for x, y in itertools.permutations(operands, 2)):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
304
305
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    The connector is only rebuilt when it is out of order or has duplicates.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        builder = exp.and_ if isinstance(expression, exp.And) else exp.or_
        operands = tuple(expression.flatten())
        by_sql = {generate(operand): operand for operand in operands}
        keys = list(by_sql)

        already_sorted = all(prev <= cur for prev, cur in zip(keys, keys[1:]))
        if not already_sorted:
            # A AND C AND B -> A AND B AND C
            expression = builder(*(by_sql[key] for key in sorted(keys)), copy=False)
        elif len(by_sql) < len(operands):
            # already in order, but duplicates were dropped
            expression = builder(*by_sql.values(), copy=False)

    return expression
330
331
def absorb_and_eliminate(expression, root=True):
    """
    Apply absorption and elimination laws to a connector's operands (in place).

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    # The laws act on operands whose connector is the opposite of the outer one.
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And
    # Identity of the inner connector: replaces a term that is complemented elsewhere.
    inner_identity = exp.true if inner_kind == exp.And else exp.false
    # Identity of the outer connector: replaces an inner operand that is absorbed.
    outer_identity = exp.false if inner_kind == exp.And else exp.true

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, inner_kind):
            continue

        a_left, a_right = a.unnest_operands()

        if is_complement(b, a_left):
            a_left.replace(inner_identity())
        elif is_complement(b, a_right):
            a_right.replace(inner_identity())
        elif (set(b.flatten()) if isinstance(b, inner_kind) else {b}) < set(a.flatten()):
            a.replace(outer_identity())
        elif isinstance(b, inner_kind):
            b_operands = b.unnest_operands()
            b_left, b_right = b_operands

            if a_left in b_operands and (
                is_complement(a_right, b_left) or is_complement(a_right, b_right)
            ):
                a.replace(a_left)
                b.replace(a_left)
            elif a_right in b_operands and (
                is_complement(a_left, b_left) or is_complement(a_left, b_right)
            ):
                a.replace(a_right)
                b.replace(a_right)

    return expression
370
371
# Date arithmetic node -> the plain arithmetic operation that undoes it.
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DatetimeAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeSub: exp.Add,
}

# Every operation simplify_equality can move across a comparison, with its inverse.
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = dict(INVERSE_DATE_OPS)
INVERSE_OPS.update({exp.Add: exp.Sub, exp.Sub: exp.Add})
384
385
def _is_number(expression: exp.Expression) -> bool:
    """Predicate: is ``expression`` a number literal?"""
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    """Predicate: is ``expression`` an INTERVAL whose amount and unit can be evaluated?"""
    if not isinstance(expression, exp.Interval):
        return False
    return extract_interval(expression) is not None
392
393
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    If the arithmetic is on the right-hand side (``3 < x + 1``), the sides are
    swapped and the comparison is mirrored (``x + 1 > 3``) before rewriting.
    Operands are only reordered for commutative addition (``1 + x = 3`` ->
    ``x = 3 - 1``); ``1 - x = 3`` is left alone because ``x = 3 + 1`` would be wrong.
    """
    if not isinstance(expression, COMPARISONS):
        return expression

    comparison = expression.__class__
    l, r = expression.left, expression.right

    if l.__class__ in INVERSE_OPS:
        pass
    elif r.__class__ in INVERSE_OPS:
        # Moving the arithmetic to the left flips <, <=, >, >=; EQ/NEQ/IS are symmetric.
        comparison = INVERSE_COMPARISONS.get(comparison, comparison)
        l, r = r, l
    else:
        return expression

    if r.is_number:
        a_predicate = _is_number
        b_predicate = _is_number
    elif _is_date_literal(r):
        a_predicate = _is_date_literal
        b_predicate = _is_interval
    else:
        return expression

    if l.__class__ in INVERSE_DATE_OPS:
        unit = l.args.get("unit")
        if not unit:
            # Without an explicit unit the interval cannot be evaluated here.
            return expression
        a = l.this
        b = exp.Interval(this=l.expression.copy(), unit=unit.copy())
    else:
        a, b = l.left, l.right

    if not a_predicate(a) and b_predicate(b):
        pass
    elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
        # Only addition is commutative, so only there may the operands swap.
        a, b = b, a
    else:
        return expression

    return comparison(this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b))
447
448
def simplify_literals(expression, root=True):
    """Fold constant binary operations and negated number literals.

    Non-connector binary nodes are flattened and reduced pairwise with
    _simplify_binary; ``-<number>`` becomes a single number literal.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    return expression
462
463
def _simplify_binary(expression, a, b):
    """Fold the binary operation ``expression`` applied to operands ``a`` and ``b``.

    Handles:
    - IS [NOT] NULL on literals and NULL;
    - NULL propagation for ordinary (non null-safe) operators;
    - arithmetic and comparisons on number literals;
    - comparisons on string literals;
    - date literal +/- interval arithmetic.

    Returns the folded expression, or None when nothing can be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Dividing by zero would raise here; whether it errors or yields
            # NULL is engine-specific, so leave it to the engine.
            if not b:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
522
523
def simplify_parens(expression):
    """Drop parentheses that cannot change how the wrapped expression is evaluated.

    Subqueries always keep their parentheses. Otherwise they are dropped when:
    - the parent is not an operator;
    - the content is not a binary operation;
    - a predicate sits under a non-predicate;
    - the content is associative with its parent (+ in +, * in *);
    - the content is a * inside a + or -.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if removable else expression
542
543
# Node types treated as constants by simplify_coalesce.
CONSTANTS = (exp.Literal, exp.Boolean, exp.Null)


def simplify_coalesce(expression):
    """Simplify COALESCE calls.

    COALESCE(x) -> x, unless it is used as a Spark partitioning hint.

    A comparison between a COALESCE and a constant is split at the first
    constant argument of the COALESCE:

        COALESCE(x, 1) = 2 -> (x IS NOT NULL AND COALESCE(x) = 2) OR (x IS NULL AND 1 = 2)

    The result is larger than the input but simplifies further on the next pass.
    Note that the COALESCE node inside ``expression`` is truncated in place.
    """
    # COALESCE(x) -> x; COALESCE is also used as a Spark partitioning hint.
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    left, right = expression.left, expression.right
    if isinstance(left, exp.Coalesce):
        coalesce, other = left, right
    elif isinstance(right, exp.Coalesce):
        coalesce, other = right, left
    else:
        return expression

    # Valid for non-constants too, but it only pays off when both sides are constant.
    if not isinstance(other, CONSTANTS):
        return expression

    position = next(
        (i for i, arg in enumerate(coalesce.expressions) if isinstance(arg, CONSTANTS)),
        None,
    )
    if position is None:
        return expression
    first_constant = coalesce.expressions[position]

    coalesce.set("expressions", coalesce.expressions[:position])

    # Drop the COALESCE wrapper right away when only `this` remains, saving
    # a simplify iteration (COALESCE(x) -> x is also handled above).
    checked = coalesce if coalesce.expressions else coalesce.this

    not_null_branch = exp.and_(
        checked.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        checked.is_(exp.null()),
        type(expression)(this=first_constant.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False))
607
608
# Concatenation node types folded by simplify_concat, and their "safe" variants.
CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    operands = expression.expressions or expression.flatten()
    merged = []
    for is_literal, run in itertools.groupby(operands, lambda e: e.is_string):
        if is_literal:
            merged.append(exp.Literal.string("".join(s.name for s in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]
    # Keep the "safe" flavour of the original concatenation.
    result_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return result_type(expressions=merged)
630
631
# Half-open [start, end) range of dates.
DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021, 1, 1), 'year') == (date(2021, 1, 1), date(2022, 1, 1))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    start = date_floor(date, unit)
    # A truncated value always lies on a unit boundary; if `date` does not,
    # the equality can never hold (except for NULL values).
    if start != date:
        return None
    return start, start + interval(unit)
651
652
def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range: ``left >= start AND left < end``."""
    start, end = drange
    return exp.and_(left >= date_literal(start), left < date_literal(end), copy=False)
660
661
def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite ``DATE_TRUNC(unit, left) = date`` as a range check on ``left``.

    Returns None when ``date`` is not on a ``unit`` boundary.
    """
    drange = _datetrunc_range(date, unit)
    return _datetrunc_eq_expression(left, drange) if drange else None
670
671
def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite ``DATE_TRUNC(unit, left) <> date`` as a range check on ``left``.

    The truncated value differs from ``date`` exactly when ``left`` lies outside
    ``[date, date + 1 unit)``, i.e. ``left < start OR left >= end``. The two
    bounds cannot both hold, so they must be joined with OR, not AND.
    Returns None when ``date`` is not on a ``unit`` boundary.
    """
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # Parenthesize so the OR keeps its meaning if this node sits inside an AND.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        )
    )
684
685
# Turns `DATE_TRUNC(unit, l) <op> d` into an equivalent predicate on `l`,
# or None when no rewrite is possible.
DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # trunc(l) < d  <=>  trunc(l) < ceil(d)  <=>  l < ceil(d)
    # (floor(d) would be wrong when d is not on a unit boundary)
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    # trunc(l) > d  <=>  trunc(l) >= floor(d) + 1 unit
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    # trunc(l) <= d  <=>  trunc(l) <= floor(d)  <=>  l < floor(d) + 1 unit
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    # trunc(l) >= d  <=>  l >= ceil(d)
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
698
699
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """True if ``left`` is a DATE_TRUNC/TIMESTAMP_TRUNC call and ``right`` is a date literal."""
    if not isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)):
        return False
    return _is_date_literal(right)
702
703
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Binary comparisons are rewritten through DATETRUNC_BINARY_COMPARISONS into a
    predicate on the un-truncated operand. `DATE_TRUNC(...) IN (d1, d2, ...)`
    becomes an OR of the merged [start, end) ranges. Anything else is returned as is.
    """
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            # d < DATE_TRUNC(...) is handled as DATE_TRUNC(...) > d
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        # `x IN (SELECT ...)` has an empty literal list: all() would be vacuously
        # true and `l.unit` would be read even when `l` is not a DATE_TRUNC.
        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            condition = exp.or_(
                *[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False
            )
            # With several ranges the result is an OR; parenthesize it so it keeps
            # its meaning if this node sits inside an AND.
            return exp.paren(condition) if len(ranges) > 1 else condition

    return expression
754
755
# (side, kind) pairs of joins that may be rewritten to CROSS JOIN when their ON
# condition is always true.
# A CROSS JOIN is empty if *either* input is empty, while an outer join ON TRUE
# keeps the rows of its preserved side: LEFT JOIN x ON TRUE keeps the left rows
# when x is empty, and RIGHT JOIN x ON TRUE keeps x's rows when the left input
# is empty. Only inner joins are therefore equivalent to CROSS JOIN.
JOINS = {
    ("", ""),
    ("", "INNER"),
}
765
766
def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible `JOIN ... ON <true>` into CROSS JOINs.

    The tree is modified in place; only joins whose (side, kind) is in JOINS and
    that have no USING or join method are converted.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
781
782
def always_true(expression):
    """Return True if ``expression`` is a constant that always evaluates to TRUE.

    TRUE and any literal count, except number literals equal to zero (``0``,
    ``0.0``), which are falsy. Treating them as TRUE would, for example, drop
    ``WHERE 0`` entirely or fold ``x OR 0`` to TRUE.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if isinstance(expression, exp.Literal):
        return not _is_zero_literal(expression)
    return False


def _is_zero_literal(expression):
    """True if ``expression`` is a number literal whose value is zero."""
    if not expression.is_number:
        return False
    try:
        return Decimal(expression.name) == 0
    except ArithmeticError:
        # Not parseable as a decimal (e.g. a dialect-specific number format).
        return False
787
788
def is_complement(a, b):
    """True if ``b`` is exactly ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
791
792
def is_false(a: exp.Expression) -> bool:
    """True for the FALSE literal (exactly exp.Boolean, subclasses excluded)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this


def is_null(a: exp.Expression) -> bool:
    """True for the NULL literal (exactly exp.Null, subclasses excluded)."""
    return exp.Null is type(a)
799
800
def eval_boolean(expression, a, b):
    """Evaluate the comparison ``expression`` on the Python values ``a`` and ``b``.

    Returns a TRUE/FALSE literal, or None if ``expression`` is not a supported comparison.
    """
    checks = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for kinds, compare in checks:
        if isinstance(expression, kinds):
            return boolean_literal(compare())
    return None
815
816
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Coerce ``value`` to a ``datetime.date``; None if it is not an ISO-format string."""
    if isinstance(value, datetime.date):
        # datetime subclasses date, so drop the time part in that case.
        return value.date() if isinstance(value, datetime.datetime) else value
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
826
827
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce ``value`` to a ``datetime.datetime`` (plain dates become midnight).

    Returns None if ``value`` is a string that is not in ISO format.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime.combine(value, datetime.time())
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
837
838
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert ``value`` to the Python temporal value for the SQL type ``to``.

    DATE yields a ``datetime.date``; any other temporal type yields a
    ``datetime.datetime``. Returns None for falsy values, non-temporal target
    types, or strings that cannot be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
847
848
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a date-literal expression to a Python date or datetime.

    Accepts ``CAST(<literal> AS <type>)`` and ``TS_OR_DS_TO_DATE(<literal>)``
    (treated as a cast to DATE), possibly nested inside one another. Returns
    None for any other shape or when the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
864
865
def _is_date_literal(expression: exp.Expression) -> bool:
    """True if ``expression`` evaluates to a constant date/datetime (see extract_date)."""
    parsed = extract_date(expression)
    return parsed is not None
868
869
def extract_interval(expression):
    """Evaluate an INTERVAL expression to a ``relativedelta``.

    Returns None when:
    - the amount is not an integer literal (e.g. ``INTERVAL x DAY`` or ``INTERVAL '1.5' DAY``);
    - the unit is unsupported;
    - dateutil is not installed.
    """
    try:
        n = int(expression.name)
    except ValueError:
        # Previously this escaped and crashed the callers, which only catch
        # UnsupportedUnit / ModuleNotFoundError.
        return None

    unit = expression.text("unit").lower()

    try:
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError):
        return None
878
879
def date_literal(date):
    """Build ``CAST('<date>' AS DATE)``, or ``AS DATETIME`` when given a datetime."""
    target = (
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE
    )
    return exp.cast(exp.Literal.string(date), target)
887
888
def interval(unit: str, n: int = 1):
    """Return a ``relativedelta`` spanning ``n`` units of ``unit``.

    Supported units: year, quarter, month, week, day, hour, minute, second.
    Importing dateutil happens first, so ModuleNotFoundError is raised if it is
    missing; any other unit raises UnsupportedUnit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spans[unit]
    return relativedelta(**{keyword: multiplier * n})
910
911
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Move ``d`` back to the start of its ``unit`` (year, quarter, month, week or day).

    Weeks are assumed to start on Monday. Only the date fields are adjusted.
    Raises UnsupportedUnit for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d
    raise UnsupportedUnit(f"Unsupported unit: {unit}")
933
934
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round ``d`` up to the next ``unit`` boundary (unchanged if already on one)."""
    start = date_floor(d, unit)
    return d if start == d else start + interval(unit)
942
943
def boolean_literal(condition):
    """Return the SQL TRUE literal if ``condition`` is truthy, otherwise FALSE."""
    if condition:
        return exp.true()
    return exp.false()
946
947
def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-reduce the flattened operands of a chain of same-type binary nodes.

    ``simplifier(expression, a, b)`` is tried on every operand pair. A result that
    is truthy and not ``expression`` itself replaces both operands and is retried
    first. When any reduction happened, the chain is rebuilt left-deep with the
    same node type; otherwise ``expression`` is returned unchanged.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    kept = []

    while pending:
        current = pending.popleft()
        merged = None
        for other in pending:
            candidate = simplifier(expression, current, other)
            if candidate and candidate is not expression:
                merged = (other, candidate)
                break

        if merged is None:
            kept.append(current)
        else:
            other, candidate = merged
            pending.remove(other)
            pending.appendleft(candidate)

    if len(kept) < original_count:
        node_type = expression.__class__
        return functools.reduce(
            lambda lhs, rhs: node_type(this=lhs, expression=rhs), kept
        )
    return expression
# Metadata key: an expression whose ``meta`` has this key set is not simplified further
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised by the date helpers (e.g. ``interval``, ``date_floor``) for a unit they do not handle."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    # A single generator instance, shared by every uniq_sort call in this run
    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze any projection that contains a GROUP BY key anywhere in its subtree
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same for HAVING: it may reference the grouped expressions verbatim
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes marked FINAL (see above) are returned untouched, children included
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        # Recurse into children; they are simplified as non-root nodes
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # Reattach to the original parent so later steps that inspect node.parent
        # (e.g. simplify_parens) see the original context
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the top-level call splices its result back into the tree; nested calls
        # are spliced by replace_children
        if root:
            expression.replace(node)

        return node

    # presumably re-applies _simplify until the tree stops changing — see while_changing
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function is called as ``func(expression, *args, **kwargs)``. If it
    raises one of ``exceptions``, the input ``expression`` is returned unchanged
    instead. Exceptions of any other type propagate.

    Args:
        *exceptions: exception classes to swallow.
    Returns:
        A decorator producing the guarded function.
    """

    def decorator(func):
        # functools.wraps copies __name__, __doc__, __module__ and sets __wrapped__, so
        # introspection (doc tools, debuggers, inspect.signature) reports the real
        # simplification function instead of this generic wrapper.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that makes a simplification function return its input expression unchanged if any of `exceptions` is raised.

def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    Non-BETWEEN nodes are returned unchanged. ``x`` is copied for each side; the
    ``low``/``high`` nodes are moved into the new comparisons without copying.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
    """
    Simplify a NOT node.

    - NOT NULL -> NULL (also when the NULL is parenthesized)
    - De Morgan's laws on a parenthesized connector:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    - NOT <always true> -> FALSE, NOT FALSE -> TRUE
    - NOT NOT x -> x

    Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this
    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: negate both sides and swap the connector
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
    """
    Remove parentheses around a nested connector of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    Each child of a connector is unwrapped with ``unnest()``. When the result is an
    instance of the parent's class, the child is replaced by it in place. The same
    node is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for child in expression.args.values():
        unwrapped = child.unnest()
        if isinstance(unwrapped, connector_type):
            child.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
    """
    Fold constant operands out of AND/OR chains, respecting SQL three-valued logic.

    Pairs of operands in a flattened connector chain are combined via
    ``_flat_simplify`` using the rules below. Non-connector expressions are returned
    unchanged.

    AND: x AND x -> x; x AND FALSE -> FALSE; NULL AND NULL -> NULL;
         x AND <always true> -> x. ``NULL AND x`` is NOT folded to NULL, because it
         evaluates to FALSE when x is FALSE.
    OR:  x OR <always true> -> TRUE; FALSE OR FALSE -> FALSE;
         NULL OR NULL / NULL OR FALSE -> NULL; x OR FALSE -> x.
    Anything else falls through to ``_simplify_comparison``.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NULL AND x is FALSE when x is FALSE, so it is only NULL for certain when
            # both sides are NULL (NULL AND FALSE was already handled above).
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
# Comparison node classes grouped by direction
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)
def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    If any operand of the flattened connector is negated by another operand, the whole
    connector collapses to FALSE (AND) or TRUE (OR). Non-root connectors with
    ``same_parent`` set, and non-connectors, are returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    for left, right in itertools.permutations(expression.flatten(), 2):
        if is_complement(left, right):
            return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by their generated SQL ``generate(operand)``. For a repeated key
    the last operand wins, but the key keeps its first position. If the unique keys are
    not already ascending, the connector is rebuilt in sorted order. Otherwise it is
    rebuilt only when duplicates were dropped. Non-root connectors with ``same_parent``
    set, and non-connectors, are returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    unique_by_sql = {generate(operand): operand for operand in operands}
    keyed = tuple(unique_by_sql.items())

    already_sorted = all(
        current[0] >= previous[0] for previous, current in zip(keyed, keyed[1:])
    )
    if not already_sorted:
        return build(*(operand for _, operand in sorted(keyed)), copy=False)

    if len(unique_by_sql) < len(operands):
        return build(*unique_by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Rewrites are applied in place with ``replace()``. Absorbed pieces become TRUE/FALSE
    constants that later simplification passes fold away. The same root node is
    returned. Non-root connectors with ``same_parent`` set, and non-connectors, are
    returned unchanged.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The opposite connector type, searched for among the operands
        # (OR terms inside an AND chain, AND terms inside an OR chain)
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Try every ordered pair (a, b) of top-level operands. permutations() snapshots
        # the operands up front, so in-place replacements below do not change which
        # pairs are visited.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # One side of `a` is NOT b: replace that side with the neutral constant
                # of `kind` (TRUE inside AND, FALSE inside OR) so it drops out later.
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands are a strict subset of a's: `a` is absorbed, so replace it
                # with the neutral constant of the outer connector.
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    # a = (aa, ab), b = (ba, bb): if they share one operand and a's other
                    # operand is negated in b, both collapse to the shared operand.
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A

def simplify_equality(expression, *args, **kwargs):
 99        def wrapped(expression, *args, **kwargs):
100            try:
101                return func(expression, *args, **kwargs)
102            except exceptions:
103                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
    """
    Simplify literal operations.

    Non-connector binary nodes are handed to ``_flat_simplify`` with the
    ``_simplify_binary`` pairwise rule. A negated number literal becomes a single
    number literal, flipping an existing leading minus sign. Anything else is
    returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    text = operand.name
    negated = text[1:] if text[0] == "-" else f"-{text}"
    return exp.Literal.number(negated)
def simplify_parens(expression):
    """
    Drop parentheses that do not affect evaluation.

    Non-Paren nodes are returned unchanged. A Paren wrapping a SELECT is always kept.
    Otherwise the wrapped node is returned when any of the listed conditions hold;
    if none hold, the Paren is kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        # parent is neither a Condition nor a Binary node
        not isinstance(parent, (exp.Condition, exp.Binary))
        # doubled parentheses: ((x))
        or isinstance(parent, exp.Paren)
        # the wrapped node is not a binary operator
        or not isinstance(this, exp.Binary)
        # a predicate whose parent is not itself a predicate
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        # associative operators: (a + b) + c, (a * b) * c
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # multiplication binds tighter than + and -: (a * b) + c
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    - ``COALESCE(x)`` with no further arguments becomes ``x``, unless the COALESCE is
      an argument of a hint.
    - A comparison between a COALESCE and a constant is split on whether the arguments
      before the first constant argument are NULL, e.g.
      ``COALESCE(x, 1) = 2`` -> ``(NOT x IS NULL AND COALESCE(x) = 2) OR (x IS NULL AND 1 = 2)``.
      Later passes fold this further. The fallback comparison keeps the COALESCE's
      original side, so non-symmetric comparisons (``<``, ``>=`` ...) keep their meaning.

    Returns the rewritten expression, or ``expression`` itself when nothing applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_is_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_is_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Drop the constant and everything after it. This mutates the COALESCE inside
    # `expression`, so the copy taken below already reflects the truncation.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # When the leading arguments are all NULL the COALESCE evaluates to `arg`, so compare
    # `arg` with `other` on the same side the COALESCE originally occupied.
    if coalesce_is_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
# Concatenation node types handled by simplify_concat
CONCATS = (exp.Concat, exp.DPipe)
# The "safe" variants; simplify_concat preserves this flavor in its output
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
def simplify_concat(expression):
    """
    Reduce all groups that contain string literals by concatenating them.

    Applies to CONCAT(...) and ``||`` chains (not CONCAT_WS). Each run of adjacent string
    literals becomes one literal. If a single operand remains, it is returned directly.
    Otherwise a Concat is built, or a SafeConcat when the input was a safe variant.
    """
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    operands = expression.expressions or expression.flatten()
    merged = []
    for is_literal_run, run in itertools.groupby(operands, key=lambda e: e.is_string):
        if is_literal_run:
            merged.append(exp.Literal.string("".join(piece.name for piece in run)))
        else:
            merged.extend(run)

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    result_class = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    if len(merged) == 1:
        return merged[0]
    return result_class(expressions=merged)

Reduces all groups that contain string literals by concatenating them.

# NOTE(review): the values below are pdoc-rendered reprs, not valid Python source.
# The LT/GT/LTE/GTE entries of DATETRUNC_BINARY_COMPARISONS are lambdas whose bodies
# are not shown on this page; EQ/NEQ map to _datetrunc_eq/_datetrunc_neq.
# DATETRUNC_COMPARISONS holds the same comparison classes plus exp.In.
DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.GT'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
 99        def wrapped(expression, *args, **kwargs):
100            try:
101                return func(expression, *args, **kwargs)
102            except exceptions:
103                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

# (join.side, join.kind) pairs that remove_where_true turns into a CROSS JOIN when
# the ON condition is always true.
# NOTE(review): RIGHT [OUTER] JOIN ... ON TRUE still emits right-side rows when the
# left input is empty, which a CROSS JOIN does not — confirm this is intended.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
def remove_where_true(expression):
    """
    Drop always-true filters from ``expression`` in place.

    - A WHERE whose condition is always true is removed from its parent.
    - A join whose ON is always true is turned into a CROSS JOIN (ON and side cleared).
      This only applies when it has no USING and no method, and its (side, kind) is
      listed in ``JOINS``.

    Returns None.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
    """
    Return a truthy value when ``expression`` is a constant that is always true.

    TRUE and any literal qualify, except numeric literals equal to zero
    (``0``, ``0.0``, ``-0``). Dialects that accept numbers as conditions (e.g. MySQL,
    SQLite) treat zero as false, so ``WHERE 0`` or ``x AND 0`` must not be folded away
    as if they were TRUE. Everything else (including None) is not always true.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # decimal.InvalidOperation (e.g. non-decimal numeric text): keep the
            # previous behavior of treating any literal as true
            return True
    return True
def is_complement(a, b):
    """True when ``b`` is ``NOT a``, i.e. a NOT node whose operand compares equal to ``a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """True for an exact ``exp.Boolean`` node (not a subclass) holding a false value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """True only for an exact ``exp.Null`` node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
    """
    Apply the comparison operator of node ``expression`` to Python values ``a`` and ``b``.

    Supports EQ and IS (both ``==``), NEQ, GT, GTE, LT and LTE, returning a TRUE/FALSE
    literal. Returns None for any other node type.
    """
    rules = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in rules:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Coerce ``value`` to a ``datetime.date``.

    Accepted inputs:
    - a ``datetime.datetime``: its date part is returned;
    - a ``datetime.date``: returned as is;
    - an ISO-8601 string: parsed with ``datetime.datetime.fromisoformat``, so a time
      part is allowed and discarded.

    Returns None for anything that cannot be parsed, including non-string values such
    as None or numbers.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: fromisoformat only accepts str; ValueError: malformed string
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce ``value`` to a ``datetime.datetime``.

    Accepted inputs:
    - a ``datetime.datetime``: returned as is;
    - a ``datetime.date``: becomes midnight of that day;
    - an ISO-8601 string: parsed with ``datetime.datetime.fromisoformat``.

    Returns None for anything that cannot be parsed, including non-string values such
    as None or numbers.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: fromisoformat only accepts str; ValueError: malformed string
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Interpret ``value`` as a date or datetime according to the target type ``to``.

    Returns a ``datetime.date`` when ``to`` is DATE, a ``datetime.datetime`` for any
    other temporal type, and None for falsy values, non-temporal target types, or
    input the ``cast_as_*`` helpers cannot parse.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Extract a constant date/datetime from ``CAST(<literal> AS <type>)`` or ``TsOrDsToDate``.

    TsOrDsToDate is treated as a cast to DATE. Nested casts are resolved recursively.
    Returns None when ``cast`` is not one of these nodes, when its argument is neither
    a literal nor a nested cast, or when ``cast_value`` cannot interpret the value.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
    """
    Convert an interval node into a ``relativedelta``.

    The amount is ``int(expression.name)`` and the unit is the lower-cased ``unit``
    text. Returns None when the unit is unsupported or ``python-dateutil`` is not
    installed. A non-integer amount raises ``ValueError`` from ``int``.
    """
    amount = int(expression.name)
    unit_name = expression.text("unit").lower()

    try:
        delta = interval(unit_name, amount)
    except (UnsupportedUnit, ModuleNotFoundError):
        delta = None
    return delta
def date_literal(date):
    """Build ``CAST('<date>' AS DATETIME)`` for datetime values, ``CAST(... AS DATE)`` otherwise."""
    literal = exp.Literal.string(date)
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
    """
    Build a ``relativedelta`` spanning ``n`` of the given ``unit``.

    Supported units: year, quarter (three months), month, week, day, hour, minute,
    second. Needs ``python-dateutil``. The import runs before the unit is checked, so
    a missing package surfaces as ``ModuleNotFoundError``.

    Raises:
        UnsupportedUnit: if ``unit`` is not one of the supported names.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier applied to n)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spans[unit]
    return relativedelta(**{keyword: multiplier * n})
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Truncate ``d`` to the first day of the given calendar ``unit``.

    Supported units: year, quarter, month, week (weeks start on Monday), day.
    "day" returns ``d`` unchanged.

    NOTE(review): if ``d`` is a ``datetime.datetime``, only the date fields are
    adjusted; any time-of-day component is kept — confirm callers expect that.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit in ("year", "quarter", "month"):
        if unit == "year":
            first_month = 1
        elif unit == "quarter":
            # months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
            first_month = 3 * ((d.month - 1) // 3) + 1
        else:
            first_month = d.month
        return d.replace(month=first_month, day=1)

    if unit == "week":
        # weekday() is 0 for Monday, so this steps back to the Monday of d's week
        return d - datetime.timedelta(days=d.weekday())

    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """
    Round ``d`` up to the next ``unit`` boundary.

    A date already on a boundary (``date_floor(d, unit) == d``) is returned as is;
    otherwise the result is ``date_floor(d, unit) + interval(unit)``.
    """
    floored = date_floor(d, unit)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
    """Return the SQL ``TRUE`` node for a truthy ``condition``, ``FALSE`` otherwise."""
    if condition:
        return exp.true()
    return exp.false()