Edit on GitHub

sqlglot.optimizer.canonicalize

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
  8
  9
 10def canonicalize(expression: exp.Expression) -> exp.Expression:
 11    """Converts a sql expression into a standard form.
 12
 13    This method relies on annotate_types because many of the
 14    conversions rely on type inference.
 15
 16    Args:
 17        expression: The expression to canonicalize.
 18    """
 19    exp.replace_children(expression, canonicalize)
 20
 21    expression = add_text_to_concat(expression)
 22    expression = replace_date_funcs(expression)
 23    expression = coerce_type(expression)
 24    expression = remove_redundant_casts(expression)
 25    expression = ensure_bool_predicates(expression)
 26    expression = remove_ascending_order(expression)
 27
 28    return expression
 29
 30
 31def add_text_to_concat(node: exp.Expression) -> exp.Expression:
 32    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
 33        node = exp.Concat(expressions=[node.left, node.right])
 34    return node
 35
 36
 37def replace_date_funcs(node: exp.Expression) -> exp.Expression:
 38    if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
 39        return exp.cast(node.this, to=exp.DataType.Type.DATE)
 40    if isinstance(node, exp.Timestamp) and not node.expression:
 41        return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)
 42    return node
 43
 44
 45def coerce_type(node: exp.Expression) -> exp.Expression:
 46    if isinstance(node, exp.Binary):
 47        _coerce_date(node.left, node.right)
 48    elif isinstance(node, exp.Between):
 49        _coerce_date(node.this, node.args["low"])
 50    elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
 51        *exp.DataType.TEMPORAL_TYPES
 52    ):
 53        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
 54    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
 55        _coerce_timeunit_arg(node.this, node.unit)
 56    elif isinstance(node, exp.DateDiff):
 57        _coerce_datediff_args(node)
 58
 59    return node
 60
 61
 62def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
 63    if (
 64        isinstance(expression, exp.Cast)
 65        and expression.to.type
 66        and expression.this.type
 67        and expression.to.type.this == expression.this.type.this
 68    ):
 69        return expression.this
 70    return expression
 71
 72
 73def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
 74    if isinstance(expression, exp.Connector):
 75        _replace_int_predicate(expression.left)
 76        _replace_int_predicate(expression.right)
 77
 78    elif isinstance(expression, (exp.Where, exp.Having)) or (
 79        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
 80        isinstance(expression, exp.If)
 81        and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
 82    ):
 83        _replace_int_predicate(expression.this)
 84
 85    return expression
 86
 87
 88def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
 89    if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False:
 90        # Convert ORDER BY a ASC to ORDER BY a
 91        expression.set("desc", None)
 92
 93    return expression
 94
 95
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
    """Coerce a pair of operands so that a DATE operand is matched with a DATE.

    Both orderings of ``(a, b)`` are examined. In each ordering:

    * if ``b`` is an ``exp.Interval``, ``a`` is first passed through
      ``_coerce_timeunit_arg`` with the interval's unit, which may cast ``a``
      in the tree;
    * if ``a`` is typed DATE and ``b`` is typed as something other than DATE
      or INTERVAL, ``b`` is wrapped in a CAST to DATE in place.

    The tree is mutated via ``replace``; nothing is returned.
    """
    # permutations() is built from the original node objects, so the second
    # ordering sees those objects, not any cast inserted during the first ordering.
    for a, b in itertools.permutations([a, b]):
        if isinstance(b, exp.Interval):
            # Rebind `a` to the node _coerce_timeunit_arg returns (the new cast
            # node if one was inserted) so the DATE check below inspects it.
            a = _coerce_timeunit_arg(a, b.unit)
        if (
            a.type
            and a.type.this == exp.DataType.Type.DATE
            and b.type
            and b.type.this
            not in (
                exp.DataType.Type.DATE,
                exp.DataType.Type.INTERVAL,
            )
        ):
            _replace_cast(b, exp.DataType.Type.DATE)
111
112
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
    """Cast a date/time argument to DATE or DATETIME based on its content and ``unit``.

    * Untyped ``arg``: returned unchanged.
    * Text-typed ``arg`` holding an ISO date, with a date unit: cast to DATE.
    * Text-typed ``arg`` holding an ISO date or ISO datetime otherwise: cast to DATETIME.
    * DATE-typed ``arg`` with a non-date unit: cast to DATETIME.

    When a cast is inserted it replaces ``arg`` in the tree, and the new cast
    node is returned. Otherwise ``arg`` itself is returned.
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    cast_to = None
    if arg_type.this in exp.DataType.TEXT_TYPES:
        literal_text = arg.name
        looks_like_date = is_iso_date(literal_text)
        if looks_like_date and is_date_unit(unit):
            cast_to = exp.DataType.Type.DATE
        # An ISO date is also an ISO datetime, but not vice versa.
        elif looks_like_date or is_iso_datetime(literal_text):
            cast_to = exp.DataType.Type.DATETIME
    elif arg_type.this == exp.DataType.Type.DATE and not is_date_unit(unit):
        cast_to = exp.DataType.Type.DATETIME

    if cast_to is None:
        return arg
    return arg.replace(exp.cast(arg.copy(), to=cast_to))
132
133
def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast each DATEDIFF operand to DATETIME unless it is already temporal.

    Operands with no type annotation are left untouched, matching how the
    other coercion helpers guard on ``.type``. Previously such operands raised
    ``AttributeError`` on ``None.this``. The tree is mutated in place.
    """
    for operand in (node.this, node.expression):
        operand_type = operand.type
        if operand_type and operand_type.this not in exp.DataType.TEMPORAL_TYPES:
            operand.replace(exp.cast(operand.copy(), to=exp.DataType.Type.DATETIME))
138
139
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
    """Wrap ``node`` in ``CAST(... AS to)`` by replacing it in the tree."""
    wrapped = exp.cast(node.copy(), to=to)
    node.replace(wrapped)
142
143
def _replace_int_predicate(expression: exp.Expression) -> None:
    """Replace an integer-typed predicate ``e`` with ``e <> 0`` in the tree.

    For ``COALESCE(...)`` the rewrite is applied recursively to each child
    instead of to the COALESCE itself. Nodes that are not integer-typed are
    left alone.
    """
    if isinstance(expression, exp.Coalesce):
        for _key, operand in expression.iter_expressions():
            _replace_int_predicate(operand)
        return

    predicate_type = expression.type
    if predicate_type and predicate_type.this in exp.DataType.INTEGER_TYPES:
        not_zero = exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0))
        expression.replace(not_zero)
def canonicalize( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Converts a SQL expression into a standard form.

    Children are processed first: ``exp.replace_children`` applies this same
    function to each child of ``expression``. The node itself is then run
    through a fixed sequence of rewrite passes, each of which may return a
    replacement node.

    The expression is expected to have been annotated with annotate_types,
    because many of the conversions rely on type inference.

    Args:
        expression: The expression to canonicalize.

    Returns:
        The canonicalized expression, which may be a different node than the input.
    """
    exp.replace_children(expression, canonicalize)

    # The passes run in this exact order; each receives the previous pass's output.
    node_passes = (
        add_text_to_concat,
        replace_date_funcs,
        coerce_type,
        remove_redundant_casts,
        ensure_bool_predicates,
        remove_ascending_order,
    )
    for apply_pass in node_passes:
        expression = apply_pass(expression)
    return expression

Converts a SQL expression into a standard form.

The expression should first be annotated with annotate_types, because many of the conversions rely on type inference.

Arguments:
  • expression: The expression to canonicalize.
def add_text_to_concat(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Rewrite a text-typed ``a + b`` as ``CONCAT(a, b)``.

    Only ``exp.Add`` nodes whose annotated type is one of
    ``exp.DataType.TEXT_TYPES`` are rewritten; everything else is returned as-is.
    """
    if not isinstance(node, exp.Add):
        return node
    add_type = node.type
    if add_type and add_type.this in exp.DataType.TEXT_TYPES:
        return exp.Concat(expressions=[node.left, node.right])
    return node
def replace_date_funcs(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    """Replace simple DATE(...) / TIMESTAMP(...) calls with equivalent CASTs.

    * ``DATE(x)`` with no extra ``expressions`` and no ``zone`` -> ``CAST(x AS DATE)``
    * ``TIMESTAMP(x)`` with no second ``expression`` -> ``CAST(x AS TIMESTAMP)``

    Any other node is returned unchanged.
    """
    is_plain_date = (
        isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone")
    )
    if is_plain_date:
        return exp.cast(node.this, to=exp.DataType.Type.DATE)

    is_plain_timestamp = isinstance(node, exp.Timestamp) and not node.expression
    if is_plain_timestamp:
        return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)

    return node
def coerce_type(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def coerce_type(node: exp.Expression) -> exp.Expression:
    """Insert casts so operands of date/time-sensitive nodes have compatible types.

    * Binary operators: operands are coerced pairwise via ``_coerce_date``.
    * BETWEEN: the tested value is coerced against both the low and the high bound.
    * EXTRACT: a typed, non-temporal source is cast to DATETIME.
    * DATE_ADD / DATE_SUB / DATE_TRUNC: ``this`` is coerced according to ``unit``.
    * DATEDIFF: both operands are coerced via ``_coerce_datediff_args``.

    Casts are inserted by replacing children in place; ``node`` itself is returned.
    """
    if isinstance(node, exp.Binary):
        _coerce_date(node.left, node.right)
    elif isinstance(node, exp.Between):
        # Coerce against both bounds so they stay consistent with each other.
        # node.this is re-read for the second call because the first call may
        # have replaced it with a cast.
        _coerce_date(node.this, node.args["low"])
        _coerce_date(node.this, node.args["high"])
    elif isinstance(node, exp.Extract):
        source_type = node.expression.type
        # Unannotated sources are left alone instead of raising AttributeError on None.
        if source_type and not source_type.is_type(*exp.DataType.TEMPORAL_TYPES):
            _replace_cast(node.expression, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Drop a CAST whose target type matches the operand's annotated type.

    Only the base type (``DataType.this``) is compared. When both the cast
    target and the operand carry a type and those base types are equal, the
    operand is returned in place of the cast.
    """
    if not isinstance(expression, exp.Cast):
        return expression

    target_type = expression.to.type
    if not target_type:
        return expression

    operand_type = expression.this.type
    if operand_type and target_type.this == operand_type.this:
        return expression.this
    return expression
def ensure_bool_predicates( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
    """Turn integer-typed predicates into explicit ``<> 0`` comparisons.

    Applied to both sides of AND/OR connectors, to WHERE/HAVING conditions, and
    to the condition of an IF/WHEN branch. The rewrite itself is done by
    ``_replace_int_predicate``. The node is modified in place and returned.
    """
    if isinstance(expression, exp.Connector):
        _replace_int_predicate(expression.left)
        _replace_int_predicate(expression.right)
        return expression

    if isinstance(expression, (exp.Where, exp.Having)):
        _replace_int_predicate(expression.this)
        return expression

    if isinstance(expression, exp.If):
        enclosing = expression.parent
        # In ``CASE x WHEN num ...`` the branch holds only the comparand for x,
        # not a full predicate, so it must not be rewritten.
        is_simple_case_branch = isinstance(enclosing, exp.Case) and enclosing.this
        if not is_simple_case_branch:
            _replace_int_predicate(expression.this)

    return expression
def remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Strip an explicit ASC from an ORDER BY term.

    An ``exp.Ordered`` whose ``desc`` arg is exactly ``False`` has its ``desc``
    arg cleared, turning ``ORDER BY a ASC`` into ``ORDER BY a``. The node is
    modified in place and returned.
    """
    if not isinstance(expression, exp.Ordered):
        return expression
    if expression.args.get("desc") is False:
        expression.set("desc", None)
    return expression