sqlglot.optimizer.canonicalize
from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp


def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Converts a sql expression into a standard form.

    This method relies on annotate_types because many of the
    conversions rely on type inference.

    Children are canonicalized first (via exp.replace_children), then the
    rewrite passes below are applied to the node itself, in order.

    Args:
        expression: The expression to canonicalize.

    Returns:
        The canonicalized expression, which may be a different node than the
        one passed in.
    """
    exp.replace_children(expression, canonicalize)

    expression = add_text_to_concat(expression)
    expression = replace_date_funcs(expression)
    expression = coerce_type(expression)
    expression = remove_redundant_casts(expression)
    expression = ensure_bool_predicates(expression)
    expression = remove_ascending_order(expression)

    return expression


def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Rewrites an Add whose annotated type is a text type into a Concat.

    Args:
        node: The expression to inspect.

    Returns:
        A new Concat of the two operands, or `node` unchanged.
    """
    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
        node = exp.Concat(expressions=[node.left, node.right])
    return node


def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    """Rewrites simple Date / Timestamp function calls into casts.

    A Date with no extra `expressions` and no `zone` argument becomes a cast to
    DATE; a Timestamp with no `expression` argument becomes a cast to TIMESTAMP.

    Args:
        node: The expression to inspect.

    Returns:
        The replacement cast, or `node` unchanged.
    """
    if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
        return exp.cast(node.this, to=exp.DataType.Type.DATE)
    if isinstance(node, exp.Timestamp) and not node.expression:
        return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)
    return node


# Expression type to transform -> arg key -> (allowed types, type to cast to)
ARG_TYPES: t.Dict[
    t.Type[exp.Expression], t.Dict[str, t.Tuple[t.Iterable[exp.DataType.Type], exp.DataType.Type]]
] = {
    exp.DateAdd: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATE)},
    exp.DateSub: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATE)},
    exp.DatetimeAdd: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)},
    exp.DatetimeSub: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)},
    exp.Extract: {"expression": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)},
}


def coerce_type(node: exp.Expression) -> exp.Expression:
    """Inserts casts so that operands have the type their parent expects.

    * Binary: the two operands are passed to `_coerce_date`.
    * Between: `this` and the `low` bound are passed to `_coerce_date`.
      NOTE(review): the `high` bound is not considered here - confirm intent.
    * Expression classes listed in ARG_TYPES: each listed argument whose
      annotated type is not one of the allowed types is wrapped in a cast.

    Args:
        node: The expression whose arguments should be coerced.

    Returns:
        `node` itself; arguments may have been replaced in place.
    """
    if isinstance(node, exp.Binary):
        _coerce_date(node.left, node.right)
    elif isinstance(node, exp.Between):
        _coerce_date(node.this, node.args["low"])
    else:
        arg_types = ARG_TYPES.get(node.__class__)
        if arg_types:
            for arg_key, (allowed, to) in arg_types.items():
                arg = node.args.get(arg_key)
                # An argument without a type annotation carries no information
                # to coerce on; skip it instead of failing on `None.is_type`.
                if arg and arg.type and not arg.type.is_type(*allowed):
                    _replace_cast(arg, to)
    return node


def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Drops a Cast whose target type matches the type of its operand.

    Only applies when both `expression.to` and `expression.this` carry a type
    annotation.

    Args:
        expression: The expression to inspect.

    Returns:
        The cast operand when the cast is redundant, else `expression`.
    """
    if (
        isinstance(expression, exp.Cast)
        and expression.to.type
        and expression.this.type
        and expression.to.type.this == expression.this.type.this
    ):
        return expression.this
    return expression


def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
    """Applies `_replace_int_predicate` to operands used as predicates.

    Covers both sides of a Connector and the condition (`this`) of a Where,
    Having or If.

    Args:
        expression: The expression to inspect.

    Returns:
        `expression` itself; operands may have been replaced in place.
    """
    if isinstance(expression, exp.Connector):
        _replace_int_predicate(expression.left)
        _replace_int_predicate(expression.right)

    elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
        _replace_int_predicate(expression.this)

    return expression


def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Clears an explicit ascending flag on an Ordered expression.

    Args:
        expression: The expression to inspect.

    Returns:
        `expression` itself, with `desc` reset to None when it was False.
    """
    if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False:
        # Convert ORDER BY a ASC to ORDER BY a
        expression.set("desc", None)

    return expression


def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
    """Casts the non-date operand to DATE when the other operand is a DATE.

    Both orderings of (a, b) are checked. An operand is only cast when it has
    a type annotation that is neither DATE nor INTERVAL.
    """
    for date_side, other_side in itertools.permutations([a, b]):
        if (
            date_side.type
            and date_side.type.this == exp.DataType.Type.DATE
            and other_side.type
            and other_side.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
        ):
            _replace_cast(other_side, exp.DataType.Type.DATE)


def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
    """Replaces `node` in its tree with a cast of a copy of itself to `to`."""
    node.replace(exp.cast(node.copy(), to=to))


def _replace_int_predicate(expression: exp.Expression) -> None:
    """Rewrites an integer-typed predicate `x` as `x <> 0`.

    For a Coalesce the rewrite is applied to each of its child expressions
    instead of to the Coalesce itself.
    """
    if isinstance(expression, exp.Coalesce):
        for _, child in expression.iter_expressions():
            _replace_int_predicate(child)
    elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
        expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Converts a sql expression into a standard form.

    This method relies on annotate_types because many of the
    conversions rely on type inference.

    The children are canonicalized first; afterwards each rewrite pass is
    applied to the node in a fixed order, feeding its result to the next.

    Args:
        expression: The expression to canonicalize.

    Returns:
        The expression produced by the last rewrite pass.
    """
    exp.replace_children(expression, canonicalize)

    # Order matters: each pass receives the output of the previous one.
    rewrite_passes = (
        add_text_to_concat,
        replace_date_funcs,
        coerce_type,
        remove_redundant_casts,
        ensure_bool_predicates,
        remove_ascending_order,
    )
    for rewrite in rewrite_passes:
        expression = rewrite(expression)

    return expression
Converts a sql expression into a standard form.
This method relies on annotate_types because many of the conversions rely on type inference.
Arguments:
- expression: The expression to canonicalize.
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    """Turns plain Date / Timestamp function calls into equivalent casts.

    Args:
        node: The expression to inspect.

    Returns:
        A cast to DATE for a Date with neither extra `expressions` nor a
        `zone` argument, a cast to TIMESTAMP for a Timestamp without an
        `expression` argument, and `node` itself otherwise.
    """
    if isinstance(node, exp.Date):
        has_extra_args = node.expressions or node.args.get("zone")
        if not has_extra_args:
            return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if isinstance(node, exp.Timestamp):
        if not node.expression:
            return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)

    return node
ARG_TYPES: Dict[Type[sqlglot.expressions.Expression], Dict[str, Tuple[Iterable[sqlglot.expressions.DataType.Type], sqlglot.expressions.DataType.Type]]] =
{<class 'sqlglot.expressions.DateAdd'>: {'this': ({<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIME: 'TIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.DATETIME64: 'DATETIME64'>, <Type.TIMETZ: 'TIMETZ'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.DATE: 'DATE'>}, <Type.DATE: 'DATE'>)}, <class 'sqlglot.expressions.DateSub'>: {'this': ({<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIME: 'TIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.DATETIME64: 'DATETIME64'>, <Type.TIMETZ: 'TIMETZ'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.DATE: 'DATE'>}, <Type.DATE: 'DATE'>)}, <class 'sqlglot.expressions.DatetimeAdd'>: {'this': ({<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIME: 'TIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.DATETIME64: 'DATETIME64'>, <Type.TIMETZ: 'TIMETZ'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.DATE: 'DATE'>}, <Type.DATETIME: 'DATETIME'>)}, <class 'sqlglot.expressions.DatetimeSub'>: {'this': ({<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIME: 'TIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.DATETIME64: 'DATETIME64'>, <Type.TIMETZ: 'TIMETZ'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.DATE: 'DATE'>}, <Type.DATETIME: 'DATETIME'>)}, <class 'sqlglot.expressions.Extract'>: {'expression': ({<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIME: 'TIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.DATETIME64: 'DATETIME64'>, <Type.TIMETZ: 'TIMETZ'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.DATE: 'DATE'>}, <Type.DATETIME: 'DATETIME'>)}}
def coerce_type(node: exp.Expression) -> exp.Expression:
    """Inserts casts so that operands have the type their parent expects.

    * Binary: the two operands are passed to `_coerce_date`.
    * Between: `this` and the `low` bound are passed to `_coerce_date`.
      NOTE(review): the `high` bound is not considered here - confirm intent.
    * Expression classes listed in ARG_TYPES: each listed argument whose
      annotated type is not one of the allowed types is wrapped in a cast.

    Args:
        node: The expression whose arguments should be coerced.

    Returns:
        `node` itself; arguments may have been replaced in place.
    """
    if isinstance(node, exp.Binary):
        _coerce_date(node.left, node.right)
    elif isinstance(node, exp.Between):
        _coerce_date(node.this, node.args["low"])
    else:
        arg_types = ARG_TYPES.get(node.__class__)
        if arg_types:
            for arg_key, (allowed, to) in arg_types.items():
                arg = node.args.get(arg_key)
                # An argument without a type annotation carries no information
                # to coerce on; skip it instead of failing on `None.is_type`.
                if arg and arg.type and not arg.type.is_type(*allowed):
                    _replace_cast(arg, to)
    return node
def remove_redundant_casts(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bool_predicates(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
    """Runs `_replace_int_predicate` on operands that act as predicates.

    For a Connector both operands are processed, left first. For a Where,
    Having or If only the condition held in `this` is processed.

    Args:
        expression: The expression to inspect.

    Returns:
        `expression` itself; its operands may have been replaced in place.
    """
    if isinstance(expression, exp.Connector):
        for operand in (expression.left, expression.right):
            _replace_int_predicate(operand)
        return expression

    if isinstance(expression, (exp.Where, exp.Having, exp.If)):
        _replace_int_predicate(expression.this)

    return expression
def remove_ascending_order(expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression: