sqlglot.optimizer.canonicalize
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import exp 7from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime 8 9 10def canonicalize(expression: exp.Expression) -> exp.Expression: 11 """Converts a sql expression into a standard form. 12 13 This method relies on annotate_types because many of the 14 conversions rely on type inference. 15 16 Args: 17 expression: The expression to canonicalize. 18 """ 19 exp.replace_children(expression, canonicalize) 20 21 expression = add_text_to_concat(expression) 22 expression = replace_date_funcs(expression) 23 expression = coerce_type(expression) 24 expression = remove_redundant_casts(expression) 25 expression = ensure_bools(expression, _replace_int_predicate) 26 expression = remove_ascending_order(expression) 27 28 return expression 29 30 31def add_text_to_concat(node: exp.Expression) -> exp.Expression: 32 if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: 33 node = exp.Concat(expressions=[node.left, node.right]) 34 return node 35 36 37def replace_date_funcs(node: exp.Expression) -> exp.Expression: 38 if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"): 39 return exp.cast(node.this, to=exp.DataType.Type.DATE) 40 if isinstance(node, exp.Timestamp) and not node.expression: 41 return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP) 42 return node 43 44 45def coerce_type(node: exp.Expression) -> exp.Expression: 46 if isinstance(node, exp.Binary): 47 _coerce_date(node.left, node.right) 48 elif isinstance(node, exp.Between): 49 _coerce_date(node.this, node.args["low"]) 50 elif isinstance(node, exp.Extract) and not node.expression.type.is_type( 51 *exp.DataType.TEMPORAL_TYPES 52 ): 53 _replace_cast(node.expression, exp.DataType.Type.DATETIME) 54 elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): 55 _coerce_timeunit_arg(node.this, node.unit) 56 elif isinstance(node, exp.DateDiff): 57 
_coerce_datediff_args(node) 58 59 return node 60 61 62def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: 63 if ( 64 isinstance(expression, exp.Cast) 65 and expression.to.type 66 and expression.this.type 67 and expression.to.type.this == expression.this.type.this 68 ): 69 return expression.this 70 return expression 71 72 73def ensure_bools( 74 expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] 75) -> exp.Expression: 76 if isinstance(expression, exp.Connector): 77 replace_func(expression.left) 78 replace_func(expression.right) 79 elif isinstance(expression, exp.Not): 80 replace_func(expression.this) 81 # We can't replace num in CASE x WHEN num ..., because it's not the full predicate 82 elif isinstance(expression, exp.If) and not ( 83 isinstance(expression.parent, exp.Case) and expression.parent.this 84 ): 85 replace_func(expression.this) 86 elif isinstance(expression, (exp.Where, exp.Having)): 87 replace_func(expression.this) 88 89 return expression 90 91 92def remove_ascending_order(expression: exp.Expression) -> exp.Expression: 93 if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False: 94 # Convert ORDER BY a ASC to ORDER BY a 95 expression.set("desc", None) 96 97 return expression 98 99 100def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: 101 for a, b in itertools.permutations([a, b]): 102 if isinstance(b, exp.Interval): 103 a = _coerce_timeunit_arg(a, b.unit) 104 if ( 105 a.type 106 and a.type.this == exp.DataType.Type.DATE 107 and b.type 108 and b.type.this 109 not in ( 110 exp.DataType.Type.DATE, 111 exp.DataType.Type.INTERVAL, 112 ) 113 ): 114 _replace_cast(b, exp.DataType.Type.DATE) 115 116 117def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: 118 if not arg.type: 119 return arg 120 121 if arg.type.this in exp.DataType.TEXT_TYPES: 122 date_text = arg.name 123 is_iso_date_ = is_iso_date(date_text) 124 125 if is_iso_date_ and 
is_date_unit(unit): 126 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) 127 128 # An ISO date is also an ISO datetime, but not vice versa 129 if is_iso_date_ or is_iso_datetime(date_text): 130 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 131 132 elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): 133 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 134 135 return arg 136 137 138def _coerce_datediff_args(node: exp.DateDiff) -> None: 139 for e in (node.this, node.expression): 140 if e.type.this not in exp.DataType.TEMPORAL_TYPES: 141 e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) 142 143 144def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None: 145 node.replace(exp.cast(node.copy(), to=to)) 146 147 148# this was originally designed for presto, there is a similar transform for tsql 149# this is different in that it only operates on int types, this is because 150# presto has a boolean type whereas tsql doesn't (people use bits) 151# with y as (select true as x) select x = 0 FROM y -- illegal presto query 152def _replace_int_predicate(expression: exp.Expression) -> None: 153 if isinstance(expression, exp.Coalesce): 154 for _, child in expression.iter_expressions(): 155 _replace_int_predicate(child) 156 elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: 157 expression.replace(expression.neq(0))
def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Rewrite a SQL expression tree into its canonical form.

    Type information produced by annotate_types is required, since most
    of the rewrites below are driven by inferred types.

    Args:
        expression: The expression to canonicalize.
    """
    # Bottom-up: canonicalize every child before touching this node.
    exp.replace_children(expression, canonicalize)

    # The passes run in a fixed order, each one feeding the next.
    canonical = add_text_to_concat(expression)
    canonical = replace_date_funcs(canonical)
    canonical = coerce_type(canonical)
    canonical = remove_redundant_casts(canonical)
    canonical = ensure_bools(canonical, _replace_int_predicate)
    return remove_ascending_order(canonical)
Converts a SQL expression into a standard form.
This function relies on annotate_types because many of the conversions rely on type inference.
Arguments:
- expression: The expression to canonicalize.
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    """Turn single-argument `Date` / `Timestamp` nodes into plain casts.

    A `Date` without extra `expressions` or a `zone` becomes a cast to DATE,
    and a `Timestamp` without an `expression` becomes a cast to TIMESTAMP.
    Every other node is handed back untouched.
    """
    is_plain_date = isinstance(node, exp.Date) and not (
        node.expressions or node.args.get("zone")
    )
    if is_plain_date:
        return exp.cast(node.this, to=exp.DataType.Type.DATE)

    is_plain_timestamp = isinstance(node, exp.Timestamp) and not node.expression
    if is_plain_timestamp:
        return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)

    return node
def coerce_type(node: exp.Expression) -> exp.Expression:
    """Wrap operands of `node` in casts so temporal operations see compatible types.

    The operands are replaced in place inside `node`; `node` itself is always
    returned.

    Args:
        node: The expression whose operands should be coerced.

    Returns:
        The same `node` object that was passed in.
    """
    if isinstance(node, exp.Binary):
        _coerce_date(node.left, node.right)
    elif isinstance(node, exp.Between):
        # Both bounds are coerced against the tested expression. `node.this` is
        # re-read for each bound because a previous pass may have replaced it.
        for bound in ("low", "high"):
            operand = node.args.get(bound)
            if operand is not None:
                _coerce_date(node.this, operand)
    elif (
        isinstance(node, exp.Extract)
        # Un-annotated operands are skipped instead of raising AttributeError.
        and node.expression.type
        and not node.expression.type.is_type(*exp.DataType.TEMPORAL_TYPES)
    ):
        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def
remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def
ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Run `replace_func` on every operand of `expression` that is used as a predicate.

    Args:
        expression: The expression to inspect.
        replace_func: Callback invoked on each predicate operand; its return
            value is ignored.

    Returns:
        The same `expression` object that was passed in.
    """
    # Operands are referenced by attribute name and looked up only when the
    # callback is about to run, so each lookup happens after earlier callbacks.
    if isinstance(expression, exp.Connector):
        predicate_attrs: t.Tuple[str, ...] = ("left", "right")
    elif isinstance(expression, exp.Not):
        predicate_attrs = ("this",)
    # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
    elif isinstance(expression, exp.If) and (
        not isinstance(expression.parent, exp.Case) or not expression.parent.this
    ):
        predicate_attrs = ("this",)
    elif isinstance(expression, (exp.Where, exp.Having)):
        predicate_attrs = ("this",)
    else:
        predicate_attrs = ()

    for attr in predicate_attrs:
        replace_func(getattr(expression, attr))

    return expression
def
remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression: