Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4import typing as t
  5from collections import deque
  6from decimal import Decimal
  7
  8from sqlglot import exp
  9from sqlglot.generator import cached_generator
 10from sqlglot.helper import first, merge_ranges, while_changing
 11
# Final means that an expression should not be simplified.
# This is the key stored in an expression's `meta` dict to flag it as final.
FINAL = "final"
 14
 15
 16class UnsupportedUnit(Exception):
 17    pass
 18
 19
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    The tree is rewritten repeatedly (via `while_changing`) by a fixed sequence
    of pre-order and post-order passes; afterwards WHERE / JOIN ... ON clauses
    that are always true are removed.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    # Generator passed to uniq_sort, which keys connector operands by its output
    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains one of the group by keys
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same for the HAVING clause, if there is one
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        """Run one simplification pass over `expression` and return the resulting node."""
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # The passes above may have produced a new node; give it the original parent
        # before running the remaining passes (simplify_parens reads `expression.parent`)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the root call swaps the result into the tree itself
        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 92
 93
 94def catch(*exceptions):
 95    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 96
 97    def decorator(func):
 98        def wrapped(expression, *args, **kwargs):
 99            try:
100                return func(expression, *args, **kwargs)
101            except exceptions:
102                return expression
103
104        return wrapped
105
106    return decorator
107
108
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Turn `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded into that form first. Any other expression is returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
121
122
def simplify_not(expression):
    """
    Simplify NOT expressions.

    De Morgan's laws:
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT over NULL, over constants and over another NOT.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        # De Morgan: negate both sides and flip the connector
        if isinstance(inner, exp.And):
            flipped = exp.or_
        elif isinstance(inner, exp.Or):
            flipped = exp.and_
        else:
            flipped = None

        if flipped is not None:
            return flipped(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
157
158
def flatten(expression):
    """
    Drop parentheses around an operand that uses the same connector as its parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, expression.__class__):
            operand.replace(unwrapped)

    return expression
170
171
def simplify_connectors(expression, root=True):
    """
    Fold constant operands (TRUE / FALSE / NULL / literals) of AND and OR connectors.

    Pairs of operands that cannot be folded that way are handed to `_simplify_comparison`.
    Anything that is not a connector is returned untouched.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    def _reduce_pair(connector, lhs, rhs):
        if lhs == rhs:
            return lhs

        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()
            if always_true(lhs):
                return exp.true() if always_true(rhs) else rhs
            if always_true(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()

            lhs_false, rhs_false = is_false(lhs), is_false(rhs)
            if lhs_false and rhs_false:
                return exp.false()

            lhs_null, rhs_null = is_null(lhs), is_null(rhs)
            if (lhs_null and (rhs_null or rhs_false)) or (lhs_false and rhs_null):
                return exp.null()

            if lhs_false:
                return rhs
            if rhs_false:
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        return None

    return _flat_simplify(expression, _reduce_pair, root)
208
209
# "Less than" and "greater than" families, used to test both strict and non-strict forms at once
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Comparison node types considered by the comparison / equality / coalesce simplifications
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Comparison to use when the two operands are swapped, e.g. `1 < x` becomes `x > 1`
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
227
228
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons against the same column into a single expression.

    Args:
        expression: the connector the two operands belong to.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the connector is an OR, False for an AND.
    Returns:
        The merged expression, `expression` itself when an operand has no
        non-column side, or None when nothing could be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands shared by both comparisons; only shared columns are of interest
        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                # l / r are the non-column side of each comparison
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            # Only number/number and string/string pairs are compared
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # Try both orderings so every rule below only has to be written once
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # NOTE(review): with equal values and one strict / one non-strict bound,
                # the result here depends on operand order -- confirm callers guarantee it
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality either contradicts the other bound or makes it redundant
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
288
289
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
304
305
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by `generate(operand)`, so that key decides both what
    counts as a duplicate and the sort order.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {generate(operand): operand for operand in operands}
    pairs = tuple(by_sql.items())

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(pairs, pairs[1:]))
    if out_of_order:
        return build(*(operand for _, operand in sorted(pairs)), copy=False)

    # already sorted, but duplicates may still have to go
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression
330
331
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place with constants or shared operands;
    `expression` itself is what gets returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector type to look for is the opposite of the outer one
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for a constant inside the nested connector
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's: swap the whole of a for a constant
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one operand and disagree only on the sign of the other
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
370
371
# Date functions mapped to the arithmetic operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations simplify_equality can move to the other side of a comparison
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
384
385
def _is_number(expression: exp.Expression) -> bool:
    """Predicate form of `expression.is_number`, for use as a callable."""
    is_number = expression.is_number
    return is_number
388
389
def _is_date(expression: exp.Expression) -> bool:
    """Return True if `expression` is a cast from which `extract_date` can get a date."""
    if not isinstance(expression, exp.Cast):
        return False
    return extract_date(expression) is not None
392
393
def _is_interval(expression: exp.Expression) -> bool:
    """Return True if `expression` is an interval that `extract_interval` can evaluate."""
    if not isinstance(expression, exp.Interval):
        return False
    return extract_interval(expression) is not None
396
397
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    If ModuleNotFoundError or UnsupportedUnit is raised along the way, the
    `catch` decorator returns the expression unchanged.
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        # Make `l` the side holding the invertible operation, `r` the other side
        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            l, r = r, l
        else:
            return expression

        # The type of `r` decides what the constant operand `b` must look like
        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date(r):
            a_predicate = _is_date
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            # Date functions carry the amount and the unit separately; combine them into an interval
            a = l.this
            b = exp.Interval(
                this=l.expression.copy(),
                unit=l.unit.copy(),
            )
        else:
            a, b = l.left, l.right

        # Exactly one operand may be constant; it ends up in `b`, the other one in `a`
        if not a_predicate(a) and b_predicate(b):
            pass
        elif not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        # Move `b` to the other side by applying the inverse operation to `r`
        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression
451
452
def simplify_literals(expression, root=True):
    """
    Fold operations on literals.

    Non-connector binary expressions are reduced pairwise with `_simplify_binary`;
    a negated number literal becomes a single literal with the sign flipped.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # A leading minus is stripped, otherwise one is added
    return exp.Literal.number(digits[1:] if digits[0] == "-" else f"-{digits}")
466
467
def _simplify_binary(expression, a, b):
    """
    Evaluate the binary operation `expression` for the operand pair `a`, `b`.

    Args:
        expression: the binary expression whose operator is applied.
        a: left operand.
        b: right operand.
    Returns:
        A literal / boolean / NULL / date literal holding the result, or None
        if the pair cannot be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        # <literal> IS [NOT] NULL and NULL IS [NOT] NULL have a known answer
        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Dividing by zero would raise ZeroDivisionError (decimal.DivisionByZero)
            # here, so the division is left in the query instead of being folded
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
526
527
def simplify_parens(expression):
    """Return the content of a parenthesized expression when the parentheses are redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # Parentheses around a SELECT are always kept
    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
546
547
# Node types that simplify_coalesce treats as constants
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)
553
554
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    A single-argument COALESCE is replaced by its argument. A comparison between
    a COALESCE and a constant is split on the first constant argument of the
    COALESCE into an OR of two cases that later passes can simplify further.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Find which side of the comparison holds the COALESCE
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Keep only the arguments that come before that constant
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            # Case 1: the remaining arguments are not NULL, so the original comparison applies
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            # Case 2: they are NULL, so the constant argument is what gets compared
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
611
612
# Concatenation node types handled by simplify_concat
CONCATS = (exp.Concat, exp.DPipe)
# The "safe" variants; simplify_concat rebuilds these as exp.SafeConcat
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
615
616
def simplify_concat(expression):
    """Merge every run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for literal_run, members in itertools.groupby(operands, lambda e: e.is_string):
        if not literal_run:
            merged.extend(members)
            continue
        merged.append(exp.Literal.string("".join(member.name for member in members)))

    # A single remaining operand needs no concatenation at all
    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    result_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return result_type(expressions=merged)
634
635
# A (start, end) pair of dates; _datetrunc_range documents it as the half-open range [min, max)
DateRange = t.Tuple[datetime.date, datetime.date]
637
638
def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Compute the range of values whose DATE_TRUNC for `unit` equals `date`.

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    start = date_floor(date, unit)

    if date == start:
        return start, start + interval(unit)

    # `date` is not on a unit boundary, so no truncated value can equal it
    # (the comparison is always False, except for NULL values).
    return None
655
656
def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Build `left >= start AND left < end` for the given date range."""
    at_or_after_start = left >= date_literal(drange[0])
    before_end = left < date_literal(drange[1])
    return exp.and_(at_or_after_start, before_end, copy=False)
664
665
def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Range check equivalent to `DATE_TRUNC(unit, left) = date`, or None if there is no such range."""
    drange = _datetrunc_range(date, unit)
    return _datetrunc_eq_expression(left, drange) if drange else None
674
675
def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """
    Build the condition equivalent to `DATE_TRUNC(unit, left) <> date`.

    Args:
        left: the expression being truncated.
        date: the date the truncated value is compared with.
        unit: the truncation unit.
    Returns:
        `left < start OR left >= end` for the range computed by `_datetrunc_range`,
        or None when there is no such range.
    """
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # Being outside of [start, end) means being on either side of it, so the two
    # bounds must be OR-ed: no value is both below the start and at or above the end.
    return exp.or_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )
688
689
# Signature of the handlers below: (truncated operand, date compared with, unit) -> new condition
DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
# How to rewrite `DATE_TRUNC(unit, l) <op> d` as a condition on `l` itself, per comparison type
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # A truncated value is below `d` exactly when `l` is below the first unit boundary at or
    # after `d`, i.e. date_ceil. With date_floor, DATE_TRUNC('year', l) < '2021-06-01' would
    # become l < '2021-01-01' and wrongly exclude the rest of 2021.
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Every expression type simplify_datetrunc_predicate looks at
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
702
703
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return True if `left` is a date/timestamp truncation and `right` a cast to a temporal type."""
    if not isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)):
        return False
    if not isinstance(right, exp.Cast):
        return False
    return right.is_type(*exp.DataType.TEMPORAL_TYPES)
710
711
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """
    Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Handles the binary comparisons of DATETRUNC_BINARY_COMPARISONS as well as IN.
    The expression is returned unchanged when it does not have the expected shape,
    or (through the `catch` decorator) when ModuleNotFoundError or UnsupportedUnit
    is raised.
    """
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        # Make `l` the truncation and `r` the date, inverting the comparison if they get swapped
        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # A handler returning None means "no rewrite": keep the expression
        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            # Collect the date range of every value in the IN list
            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            # One range check per merged range, OR-ed together
            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression
762
763
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
# The entries are the (side, kind) pairs that remove_where_true may turn into a CROSS join.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
773
774
def remove_where_true(expression):
    """
    Drop conditions that are always true, modifying `expression` in place.

    A WHERE clause that is always true is removed. A join whose ON condition is
    always true becomes a CROSS join, provided its (side, kind) is listed in JOINS
    and it has neither a USING clause nor a join method.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
789
790
def always_true(expression):
    """Return a truthy value if `expression` is a true boolean or any literal."""
    true_boolean = isinstance(expression, exp.Boolean) and expression.this
    if true_boolean:
        return true_boolean
    return isinstance(expression, exp.Literal)
795
796
def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
799
800
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Boolean` node holding a false value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
803
804
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Null` node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
807
808
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns:
        The boolean literal holding the result, or None if `expression`
        is not one of the supported comparisons.
    """
    # Checked in order; the comparison itself only runs for the matching entry
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
823
824
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a date, a datetime or an ISO formatted string to a date; None if the string is invalid."""
    if isinstance(value, datetime.date):
        # datetime is a subclass of date, so it has to be narrowed down explicitly
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
834
835
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a datetime, a date or an ISO formatted string to a datetime; None if the string is invalid."""
    if isinstance(value, datetime.date):
        if isinstance(value, datetime.datetime):
            return value
        # A plain date becomes midnight of that day
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None

    return parsed
845
846
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to the Python equivalent of the data type `to`.

    Args:
        value: the value to convert.
        to: the target data type.
    Returns:
        A date for DATE, a datetime for the other temporal types, and None
        for a falsy value, any other type, or a value that cannot be converted.
    """
    if not value:
        return None
    # DATE is checked first, before the broader temporal check below
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
855
856
def extract_date(cast: exp.Cast) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Get the date or datetime that a (possibly nested) cast of a literal stands for.

    Args:
        cast: the cast expression to evaluate.
    Returns:
        The value converted by `cast_value` for the cast's target type, or None
        if the innermost operand is not a literal or the conversion fails.
    """
    value: t.Any
    if isinstance(cast.this, exp.Literal):
        value = cast.this.name
    elif isinstance(cast.this, exp.Cast):
        # Nested casts are resolved from the inside out
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, cast.to)
866
867
def extract_interval(expression):
    """
    Convert an interval expression to the time delta built by `interval`.

    Args:
        expression: the interval expression; its name holds the amount and its
            "unit" text holds the unit.
    Returns:
        The time delta, or None if the amount is not an integer, the unit is
        not supported or the date library is not installed.
    """
    try:
        n = int(expression.name)
    except ValueError:
        # e.g. a fractional or non-numeric amount: nothing to evaluate, but no reason to fail
        return None

    unit = expression.text("unit").lower()

    try:
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError):
        return None
876
877
def date_literal(date):
    """Wrap `date` in a string literal cast to DATETIME (for datetimes) or DATE (otherwise)."""
    type_name = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), type_name)
883
884
def interval(unit: str, n: int = 1):
    """
    Build the `relativedelta` for `n` times `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed below.
    """
    from dateutil.relativedelta import relativedelta

    # unit name -> (relativedelta keyword, how many of it make up one unit)
    units = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, amount in units:
        if unit == name:
            return relativedelta(**{keyword: amount * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
906
907
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Round `d` down to the start of its `unit`.

    Args:
        d: the date or datetime to round down.
        unit: one of "year", "quarter", "month", "week" or "day".
    Returns:
        The start of the unit, of the same type as `d`.
    Raises:
        UnsupportedUnit: if `unit` is not one of the units above.
    """
    if isinstance(d, datetime.datetime):
        # None of the supported units is shorter than a day, so the floor of a
        # datetime is always at midnight; without this the time-of-day would survive
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
929
930
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next start of `unit`; a date already on a boundary is returned as is."""
    start = date_floor(d, unit)
    return d if start == d else start + interval(unit)
938
939
def boolean_literal(condition):
    """Return the TRUE expression for a truthy `condition`, the FALSE expression otherwise."""
    if condition:
        return exp.true()
    return exp.false()
942
943
def _flat_simplify(expression, simplifier, root=True):
    """
    Simplify the flattened operands of `expression` pairwise.

    Args:
        expression: the expression whose flattened operands are combined.
        simplifier: called as `simplifier(expression, a, b)`; a truthy result
            other than `expression` replaces the pair `a`, `b`.
        root: whether `expression` is the root of the current simplification.
    Returns:
        A chain of `expression`'s type rebuilt from the remaining operands if at
        least one pair was combined, otherwise `expression` itself.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # a and b collapse into `result`, which goes back to the front
                    # of the queue so that it gets paired with the other operands
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # a could not be combined with any operand left in the queue
                operands.append(a)

        # Rebuild only if something was combined
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date or interval unit is not one the simplifier handles."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression):
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    The individual simplification passes are applied repeatedly (through
    `while_changing`) and, once that returns, always-true WHERE / JOIN ON
    conditions are removed.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # A projection that contains any group by key is frozen as a whole
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same treatment for a HAVING clause that contains a group by key
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes marked FINAL above are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # NOTE(review): presumably restores the parent link for the passes below that
        # inspect `parent` (e.g. simplify_parens), since `node` may be a new object — confirm
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the root call swaps the result into the tree itself
        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
 95def catch(*exceptions):
 96    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 97
 98    def decorator(func):
 99        def wrapped(expression, *args, **kwargs):
100            try:
101                return func(expression, *args, **kwargs)
102            except exceptions:
103                return expression
104
105        return wrapped
106
107    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded into that form first. Any other expression is returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Each bound gets its own copy of the tested operand
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws are applied to a parenthesized connector:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    In addition NOT NULL -> NULL, NOT <always true> -> FALSE,
    NOT FALSE -> TRUE and NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        # Negating a connector flips it to its dual and negates both sides
        for connector, dual in ((exp.And, exp.or_), (exp.Or, exp.and_)):
            if isinstance(condition, connector):
                return dual(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )

        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
def flatten(expression):
    """
    Remove the nesting around an operand that is the same connector type as its parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Fold the operands of an AND / OR using boolean and NULL logic.

    Expressions that are not connectors are returned unchanged; connectors
    are handed to `_flat_simplify` together with the pairwise rule below.
    """

    def _simplify_connectors(expression, lhs, rhs):
        # x AND x -> x, x OR x -> x
        if lhs == rhs:
            return lhs

        if isinstance(expression, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()
            if always_true(lhs):
                return exp.true() if always_true(rhs) else rhs
            if always_true(rhs):
                return lhs
            return _simplify_comparison(expression, lhs, rhs)
        elif isinstance(expression, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()
            if is_false(lhs):
                if is_false(rhs):
                    return exp.false()
                # FALSE OR NULL -> NULL, FALSE OR x -> x
                return exp.null() if is_null(rhs) else rhs
            if is_null(lhs) and (is_null(rhs) or is_false(rhs)):
                return exp.null()
            if is_false(rhs):
                return lhs
            return _simplify_comparison(expression, lhs, rhs, or_=True)

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_compliments(expression, root=True):
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not (root or not expression.same_parent):
        return expression

    for operand, candidate in itertools.permutations(expression.flatten(), 2):
        if is_complement(operand, candidate):
            return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Removing complements.

A AND NOT A -> FALSE A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate the operands of a connector and order them by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    Args:
        expression: the expression to normalize; non-connectors are returned as is
        generate: callable producing the SQL text used as dedup / sort key
        root: when false, the work only happens if `expression.same_parent` is falsy
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not (root or not expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {generate(operand): operand for operand in operands}
    keys = list(by_sql)

    # Rebuild in sorted order only if some neighbouring pair is out of order
    # A AND C AND B -> A AND B AND C
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    # Already sorted, but duplicates may still have to go
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, rewriting operands in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced by TRUE / FALSE literals or by a shared sub-operand
    inside the existing tree; the same `expression` object is always returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The laws apply to operands of the opposite connector type
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # `aa` is NOT b: swap it for the literal that is neutral within `a`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # the operands of `b` are a proper subset of those of `a`: swap all
                    # of `a` for the literal that is neutral within the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # `a` and `b` share one operand and disagree only in the negation
                    # of the other one: both collapse to the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def simplify_equality(expression, *args, **kwargs):
 99        def wrapped(expression, *args, **kwargs):
100            try:
101                return func(expression, *args, **kwargs)
102            except exceptions:
103                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    Binary expressions that are not connectors are combined pairwise through
    `_flat_simplify` with `_simplify_binary`. A negated number literal is
    turned into a single number literal with the sign flipped.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # -(-1) -> 1, -(1) -> -1
    flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(flipped)
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Strip parentheses that the checks below consider redundant.

    Returns the wrapped expression when the parentheses can go, otherwise
    `expression` itself. Parenthesized SELECTs are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    A COALESCE with a single argument is replaced by that argument. A
    comparison between a COALESCE and a constant is expanded around the
    first constant argument of the COALESCE:

        (<rest> IS NOT NULL AND <comparison>) OR (<rest> IS NULL AND <constant arg> <op> <other>)

    where <rest> is the COALESCE truncated before that argument.
    """
    # COALESCE(x) -> x
    if isinstance(expression, exp.Coalesce):
        # COALESCE is also used as a Spark partitioning hint
        if not expression.expressions and not isinstance(expression.parent, exp.Hint):
            return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs = expression.left
    rhs = expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    first_constant = next(
        (
            (index, candidate)
            for index, candidate in enumerate(coalesce.expressions)
            if isinstance(candidate, CONSTANTS)
        ),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    if not coalesce.expressions:
        coalesce = coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    when_not_null = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(when_not_null, when_null, copy=False))
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
SAFE_CONCATS = (<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.SafeDPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """
    Merge adjacent string literals in a concatenation.

    Every run of consecutive string literals becomes one literal. If only a
    single argument is left it is returned directly; otherwise a new
    Concat / SafeConcat holding the merged arguments is returned.
    """
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()

    merged = []
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if not all_strings:
            merged.extend(run)
            continue
        merged.append(exp.Literal.string("".join(literal.name for literal in run)))

    if len(merged) == 1:
        return merged[0]

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    if isinstance(expression, SAFE_CONCATS):
        return exp.SafeConcat(expressions=merged)
    return exp.Concat(expressions=merged)

Reduces all groups that contain string literals by concatenating them.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GT'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
 99        def wrapped(expression, *args, **kwargs):
100            try:
101                return func(expression, *args, **kwargs)
102            except exceptions:
103                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('RIGHT', 'OUTER'), ('', 'INNER'), ('RIGHT', ''), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Drop conditions that are always true, modifying `expression` in place.

    A WHERE clause whose condition is always true is removed. A join whose
    ON condition is always true, that has no USING / method and whose
    (side, kind) pair is listed in JOINS, is turned into a CROSS join.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """
    Tell whether `expression` is treated as unconditionally true.

    That is the case for a Boolean node holding a truthy value and for any
    Literal node.
    """
    # NOTE(review): every Literal counts as true here, which looks like it would
    # include a numeric 0 — confirm this is intended.
    if isinstance(expression, exp.Boolean):
        value = expression.this
        if value:
            return value
    return isinstance(expression, exp.Literal)
def is_complement(a, b):
def is_complement(a, b):
    """Tell whether `b` is a NOT node wrapping something equal to `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Tell whether `a` is exactly a Boolean node (subclasses excluded) holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Tell whether `a` is exactly a Null node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns:
        The TRUE / FALSE literal for the outcome, or None when `expression`
        is not one of EQ, IS, NEQ, GT, GTE, LT or LTE.
    """
    # The comparisons are wrapped in lambdas so only the matching one is evaluated
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a date.

    Datetimes are truncated to their date part, dates pass through, and
    anything else is parsed as an ISO-format string. Returns None when the
    parse fails with a ValueError.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date, so narrow it here
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a datetime.

    Datetimes pass through, dates become midnight of that day, and anything
    else is parsed as an ISO-format string. Returns None when the parse
    fails with a ValueError.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to the Python date type that corresponds to the SQL type `to`.

    Args:
        value: the value to convert; any falsy value yields None
        to: the target data type
    Returns:
        A `datetime.date` for DATE, a `datetime.datetime` for the other
        temporal types, and None for non-temporal types or when the
        conversion fails.
    """
    if not value:
        return None
    # DATE is checked first so that it yields a plain date rather than a datetime
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Cast) -> Optional[datetime.date]:
def extract_date(cast: exp.Cast) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Evaluate a CAST over a literal (or over nested CASTs of one) to a Python date value.

    Args:
        cast: the cast expression to evaluate
    Returns:
        The value converted to the cast's target type via `cast_value`, or
        None when the innermost operand is neither a Literal nor a Cast.
    """
    value: t.Any
    if isinstance(cast.this, exp.Literal):
        value = cast.this.name
    elif isinstance(cast.this, exp.Cast):
        # Resolve the inner cast first, then convert its result to the outer type
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, cast.to)
def extract_interval(expression):
def extract_interval(expression):
    """
    Convert an interval expression into the delta object returned by `interval`.

    Args:
        expression: expression whose `name` is the quantity and whose "unit"
            text is the unit
    Returns:
        The delta, or None when the interval cannot be evaluated: the quantity
        is not an integer, the unit is unsupported, or `interval` raises
        ModuleNotFoundError for its function-level dateutil import.
    """
    try:
        n = int(expression.name)
    except ValueError:
        # A non-integer quantity cannot be folded; report it like an unsupported
        # unit instead of letting the error escape the simplifier.
        return None

    unit = expression.text("unit").lower()

    try:
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError):
        return None
def date_literal(date):
def date_literal(date):
    """Build a CAST of the string form of `date` to DATETIME (for datetimes) or DATE."""
    if isinstance(date, datetime.datetime):
        target_type = "DATETIME"
    else:
        target_type = "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """
    Build a dateutil `relativedelta` spanning `n` times `unit`.

    Args:
        unit: one of "year", "quarter", "month", "week", "day", "hour",
            "minute" or "second"
        n: how many units the delta covers
    Raises:
        UnsupportedUnit: if `unit` is none of the units listed above
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, how many of that keyword make up one unit)
    spans = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, factor in spans:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Round ``d`` down to the first day of the ``unit`` that contains it.

    Only the month/day fields are adjusted, so when ``d`` carries a
    time-of-day component it is left as it was.

    Args:
        d: the date to truncate
        unit: one of "year", "quarter", "month", "week" or "day"
    Returns:
        The truncated date.
    Raises:
        UnsupportedUnit: if ``unit`` is none of the units listed above
    """
    if unit == "day":
        return d
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # Quarters begin in months 1, 4, 7 and 10
        quarter_start = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=quarter_start, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """
    Round ``d`` up to the next ``unit`` boundary.

    A date that already equals its own floor is returned unchanged;
    otherwise the result is the floor plus one ``unit``.
    """
    lower = date_floor(d, unit)
    return d if lower == d else lower + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression when ``condition`` is truthy, otherwise the FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()