Edit on GitHub

sqlglot.optimizer.annotate_types

  1from __future__ import annotations
  2
  3import datetime
  4import functools
  5import typing as t
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.helper import ensure_list, subclasses
 10from sqlglot.optimizer.scope import Scope, traverse_scope
 11from sqlglot.schema import Schema, ensure_schema
 12
 13if t.TYPE_CHECKING:
 14    B = t.TypeVar("B", bound=exp.Binary)
 15
 16    BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
 17    BinaryCoercions = t.Dict[
 18        t.Tuple[exp.DataType.Type, exp.DataType.Type],
 19        BinaryCoercionFunc,
 20    ]
 21
 22
 23# Interval units that operate on date components
 24DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
 25
 26
 27def annotate_types(
 28    expression: E,
 29    schema: t.Optional[t.Dict | Schema] = None,
 30    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
 31    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
 32) -> E:
 33    """
 34    Infers the types of an expression, annotating its AST accordingly.
 35
 36    Example:
 37        >>> import sqlglot
 38        >>> schema = {"y": {"cola": "SMALLINT"}}
 39        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 40        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 41        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 42        <Type.DOUBLE: 'DOUBLE'>
 43
 44    Args:
 45        expression: Expression to annotate.
 46        schema: Database schema.
 47        annotators: Maps expression type to corresponding annotation function.
 48        coerces_to: Maps expression type to set of types that it can be coerced into.
 49
 50    Returns:
 51        The expression annotated with types.
 52    """
 53
 54    schema = ensure_schema(schema)
 55
 56    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 57
 58
 59def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
 60    return lambda self, e: self._annotate_with_type(e, data_type)
 61
 62
 63def _is_iso_date(text: str) -> bool:
 64    try:
 65        datetime.date.fromisoformat(text)
 66        return True
 67    except ValueError:
 68        return False
 69
 70
 71def _is_iso_datetime(text: str) -> bool:
 72    try:
 73        datetime.datetime.fromisoformat(text)
 74        return True
 75    except ValueError:
 76        return False
 77
 78
 79def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
 80    date_text = l.name
 81    unit = r.text("unit").lower()
 82
 83    is_iso_date = _is_iso_date(date_text)
 84
 85    if is_iso_date and unit in DATE_UNITS:
 86        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
 87        return exp.DataType.Type.DATE
 88
 89    # An ISO date is also an ISO datetime, but not vice versa
 90    if is_iso_date or _is_iso_datetime(date_text):
 91        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
 92        return exp.DataType.Type.DATETIME
 93
 94    return exp.DataType.Type.UNKNOWN
 95
 96
 97def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
 98    unit = r.text("unit").lower()
 99    if unit not in DATE_UNITS:
100        return exp.DataType.Type.DATETIME
101    return l.type.this if l.type else exp.DataType.Type.UNKNOWN
102
103
104def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
105    @functools.wraps(func)
106    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
107        return func(r, l)
108
109    return _swapped
110
111
112def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
113    return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}
114
115
116class _TypeAnnotator(type):
117    def __new__(cls, clsname, bases, attrs):
118        klass = super().__new__(cls, clsname, bases, attrs)
119
120        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
121        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
122        text_precedence = (
123            exp.DataType.Type.TEXT,
124            exp.DataType.Type.NVARCHAR,
125            exp.DataType.Type.VARCHAR,
126            exp.DataType.Type.NCHAR,
127            exp.DataType.Type.CHAR,
128        )
129        numeric_precedence = (
130            exp.DataType.Type.DOUBLE,
131            exp.DataType.Type.FLOAT,
132            exp.DataType.Type.DECIMAL,
133            exp.DataType.Type.BIGINT,
134            exp.DataType.Type.INT,
135            exp.DataType.Type.SMALLINT,
136            exp.DataType.Type.TINYINT,
137        )
138        timelike_precedence = (
139            exp.DataType.Type.TIMESTAMPLTZ,
140            exp.DataType.Type.TIMESTAMPTZ,
141            exp.DataType.Type.TIMESTAMP,
142            exp.DataType.Type.DATETIME,
143            exp.DataType.Type.DATE,
144        )
145
146        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
147            coerces_to = set()
148            for data_type in type_precedence:
149                klass.COERCES_TO[data_type] = coerces_to.copy()
150                coerces_to |= {data_type}
151
152        return klass
153
154
155class TypeAnnotator(metaclass=_TypeAnnotator):
156    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
157        exp.DataType.Type.BIGINT: {
158            exp.ApproxDistinct,
159            exp.ArraySize,
160            exp.Count,
161            exp.Length,
162        },
163        exp.DataType.Type.BOOLEAN: {
164            exp.Between,
165            exp.Boolean,
166            exp.In,
167            exp.RegexpLike,
168        },
169        exp.DataType.Type.DATE: {
170            exp.CurrentDate,
171            exp.Date,
172            exp.DateFromParts,
173            exp.DateStrToDate,
174            exp.DateTrunc,
175            exp.DiToDate,
176            exp.StrToDate,
177            exp.TimeStrToDate,
178            exp.TsOrDsToDate,
179        },
180        exp.DataType.Type.DATETIME: {
181            exp.CurrentDatetime,
182            exp.DatetimeAdd,
183            exp.DatetimeSub,
184        },
185        exp.DataType.Type.DOUBLE: {
186            exp.ApproxQuantile,
187            exp.Avg,
188            exp.Exp,
189            exp.Ln,
190            exp.Log,
191            exp.Log2,
192            exp.Log10,
193            exp.Pow,
194            exp.Quantile,
195            exp.Round,
196            exp.SafeDivide,
197            exp.Sqrt,
198            exp.Stddev,
199            exp.StddevPop,
200            exp.StddevSamp,
201            exp.Variance,
202            exp.VariancePop,
203        },
204        exp.DataType.Type.INT: {
205            exp.Ceil,
206            exp.DateDiff,
207            exp.DatetimeDiff,
208            exp.Extract,
209            exp.TimestampDiff,
210            exp.TimeDiff,
211            exp.DateToDi,
212            exp.Floor,
213            exp.Levenshtein,
214            exp.StrPosition,
215            exp.TsOrDiToDi,
216        },
217        exp.DataType.Type.TIMESTAMP: {
218            exp.CurrentTime,
219            exp.CurrentTimestamp,
220            exp.StrToTime,
221            exp.TimeAdd,
222            exp.TimeStrToTime,
223            exp.TimeSub,
224            exp.Timestamp,
225            exp.TimestampAdd,
226            exp.TimestampSub,
227            exp.UnixToTime,
228        },
229        exp.DataType.Type.TINYINT: {
230            exp.Day,
231            exp.Month,
232            exp.Week,
233            exp.Year,
234        },
235        exp.DataType.Type.VARCHAR: {
236            exp.ArrayConcat,
237            exp.Concat,
238            exp.ConcatWs,
239            exp.DateToDateStr,
240            exp.GroupConcat,
241            exp.Initcap,
242            exp.Lower,
243            exp.SafeConcat,
244            exp.SafeDPipe,
245            exp.Substring,
246            exp.TimeToStr,
247            exp.TimeToTimeStr,
248            exp.Trim,
249            exp.TsOrDsToDateStr,
250            exp.UnixToStr,
251            exp.UnixToTimeStr,
252            exp.Upper,
253        },
254    }
255
256    ANNOTATORS: t.Dict = {
257        **{
258            expr_type: lambda self, e: self._annotate_unary(e)
259            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
260        },
261        **{
262            expr_type: lambda self, e: self._annotate_binary(e)
263            for expr_type in subclasses(exp.__name__, exp.Binary)
264        },
265        **{
266            expr_type: _annotate_with_type_lambda(data_type)
267            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
268            for expr_type in expressions
269        },
270        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
271        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
272        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
273        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
274        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
275        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
276        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
277        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
278        exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
279        exp.DateSub: lambda self, e: self._annotate_dateadd(e),
280        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
281        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
282        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
283        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
284        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
285        exp.Literal: lambda self, e: self._annotate_literal(e),
286        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
287        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
288        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
289        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
290        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
291        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
292        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
293    }
294
295    NESTED_TYPES = {
296        exp.DataType.Type.ARRAY,
297    }
298
299    # Specifies what types a given type can be coerced into (autofilled)
300    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
301
302    # Coercion functions for binary operations.
303    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
304    BINARY_COERCIONS: BinaryCoercions = {
305        **swap_all(
306            {
307                (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
308                for t in exp.DataType.TEXT_TYPES
309            }
310        ),
311        **swap_all(
312            {
313                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
314            }
315        ),
316    }
317
318    def __init__(
319        self,
320        schema: Schema,
321        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
322        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
323        binary_coercions: t.Optional[BinaryCoercions] = None,
324    ) -> None:
325        self.schema = schema
326        self.annotators = annotators or self.ANNOTATORS
327        self.coerces_to = coerces_to or self.COERCES_TO
328        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
329
330        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
331        self._visited: t.Set[int] = set()
332
333    def _set_type(
334        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
335    ) -> None:
336        expression.type = target_type  # type: ignore
337        self._visited.add(id(expression))
338
339    def annotate(self, expression: E) -> E:
340        for scope in traverse_scope(expression):
341            selects = {}
342            for name, source in scope.sources.items():
343                if not isinstance(source, Scope):
344                    continue
345                if isinstance(source.expression, exp.UDTF):
346                    values = []
347
348                    if isinstance(source.expression, exp.Lateral):
349                        if isinstance(source.expression.this, exp.Explode):
350                            values = [source.expression.this.this]
351                    else:
352                        values = source.expression.expressions[0].expressions
353
354                    if not values:
355                        continue
356
357                    selects[name] = {
358                        alias: column
359                        for alias, column in zip(
360                            source.expression.alias_column_names,
361                            values,
362                        )
363                    }
364                else:
365                    selects[name] = {
366                        select.alias_or_name: select for select in source.expression.selects
367                    }
368
369            # First annotate the current scope's column references
370            for col in scope.columns:
371                if not col.table:
372                    continue
373
374                source = scope.sources.get(col.table)
375                if isinstance(source, exp.Table):
376                    self._set_type(col, self.schema.get_column_type(source, col))
377                elif source and col.table in selects and col.name in selects[col.table]:
378                    self._set_type(col, selects[col.table][col.name].type)
379
380            # Then (possibly) annotate the remaining expressions in the scope
381            self._maybe_annotate(scope.expression)
382
383        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
384
385    def _maybe_annotate(self, expression: E) -> E:
386        if id(expression) in self._visited:
387            return expression  # We've already inferred the expression's type
388
389        annotator = self.annotators.get(expression.__class__)
390
391        return (
392            annotator(self, expression)
393            if annotator
394            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
395        )
396
397    def _annotate_args(self, expression: E) -> E:
398        for _, value in expression.iter_expressions():
399            self._maybe_annotate(value)
400
401        return expression
402
403    def _maybe_coerce(
404        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
405    ) -> exp.DataType | exp.DataType.Type:
406        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
407        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
408
409        # We propagate the NULL / UNKNOWN types upwards if found
410        if exp.DataType.Type.NULL in (type1_value, type2_value):
411            return exp.DataType.Type.NULL
412        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
413            return exp.DataType.Type.UNKNOWN
414
415        if type1_value in self.NESTED_TYPES:
416            return type1
417        if type2_value in self.NESTED_TYPES:
418            return type2
419
420        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
421
422    # Note: the following "no_type_check" decorators were added because mypy was yelling due
423    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
424    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
425
426    @t.no_type_check
427    def _annotate_binary(self, expression: B) -> B:
428        self._annotate_args(expression)
429
430        left, right = expression.left, expression.right
431        left_type, right_type = left.type.this, right.type.this
432
433        if isinstance(expression, exp.Connector):
434            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
435                self._set_type(expression, exp.DataType.Type.NULL)
436            elif exp.DataType.Type.NULL in (left_type, right_type):
437                self._set_type(
438                    expression,
439                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
440                )
441            else:
442                self._set_type(expression, exp.DataType.Type.BOOLEAN)
443        elif isinstance(expression, exp.Predicate):
444            self._set_type(expression, exp.DataType.Type.BOOLEAN)
445        elif (left_type, right_type) in self.binary_coercions:
446            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
447        else:
448            self._set_type(expression, self._maybe_coerce(left_type, right_type))
449
450        return expression
451
452    @t.no_type_check
453    def _annotate_unary(self, expression: E) -> E:
454        self._annotate_args(expression)
455
456        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
457            self._set_type(expression, exp.DataType.Type.BOOLEAN)
458        else:
459            self._set_type(expression, expression.this.type)
460
461        return expression
462
463    @t.no_type_check
464    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
465        if expression.is_string:
466            self._set_type(expression, exp.DataType.Type.VARCHAR)
467        elif expression.is_int:
468            self._set_type(expression, exp.DataType.Type.INT)
469        else:
470            self._set_type(expression, exp.DataType.Type.DOUBLE)
471
472        return expression
473
474    @t.no_type_check
475    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
476        self._set_type(expression, target_type)
477        return self._annotate_args(expression)
478
479    @t.no_type_check
480    def _annotate_by_args(
481        self, expression: E, *args: str, promote: bool = False, array: bool = False
482    ) -> E:
483        self._annotate_args(expression)
484
485        expressions: t.List[exp.Expression] = []
486        for arg in args:
487            arg_expr = expression.args.get(arg)
488            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
489
490        last_datatype = None
491        for expr in expressions:
492            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
493
494        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
495
496        if promote:
497            if expression.type.this in exp.DataType.INTEGER_TYPES:
498                self._set_type(expression, exp.DataType.Type.BIGINT)
499            elif expression.type.this in exp.DataType.FLOAT_TYPES:
500                self._set_type(expression, exp.DataType.Type.DOUBLE)
501
502        if array:
503            self._set_type(
504                expression,
505                exp.DataType(
506                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
507                ),
508            )
509
510        return expression
511
512    def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
513        self._annotate_args(expression)
514
515        if expression.this.type.this in exp.DataType.TEXT_TYPES:
516            datatype = _coerce_literal_and_interval(expression.this, expression.interval())
517        elif (
518            expression.this.type.is_type(exp.DataType.Type.DATE)
519            and expression.text("unit").lower() not in DATE_UNITS
520        ):
521            datatype = exp.DataType.Type.DATETIME
522        else:
523            datatype = expression.this.type
524
525        self._set_type(expression, datatype)
526        return expression
DATE_UNITS = {'quarter', 'day', 'year', 'week', 'month', 'year_month'}
def annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
28def annotate_types(
29    expression: E,
30    schema: t.Optional[t.Dict | Schema] = None,
31    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
32    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
33) -> E:
34    """
35    Infers the types of an expression, annotating its AST accordingly.
36
37    Example:
38        >>> import sqlglot
39        >>> schema = {"y": {"cola": "SMALLINT"}}
40        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
41        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
42        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
43        <Type.DOUBLE: 'DOUBLE'>
44
45    Args:
46        expression: Expression to annotate.
47        schema: Database schema.
48        annotators: Maps expression type to corresponding annotation function.
49        coerces_to: Maps expression type to set of types that it can be coerced into.
50
51    Returns:
52        The expression annotated with types.
53    """
54
55    schema = ensure_schema(schema)
56
57    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

Infers the types of an expression, annotating its AST accordingly.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression: Expression to annotate.
  • schema: Database schema.
  • annotators: Maps expression type to corresponding annotation function.
  • coerces_to: Maps expression type to set of types that it can be coerced into.
Returns:

The expression annotated with types.

105def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
106    @functools.wraps(func)
107    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
108        return func(r, l)
109
110    return _swapped
113def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
114    return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}
class TypeAnnotator:
156class TypeAnnotator(metaclass=_TypeAnnotator):
157    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
158        exp.DataType.Type.BIGINT: {
159            exp.ApproxDistinct,
160            exp.ArraySize,
161            exp.Count,
162            exp.Length,
163        },
164        exp.DataType.Type.BOOLEAN: {
165            exp.Between,
166            exp.Boolean,
167            exp.In,
168            exp.RegexpLike,
169        },
170        exp.DataType.Type.DATE: {
171            exp.CurrentDate,
172            exp.Date,
173            exp.DateFromParts,
174            exp.DateStrToDate,
175            exp.DateTrunc,
176            exp.DiToDate,
177            exp.StrToDate,
178            exp.TimeStrToDate,
179            exp.TsOrDsToDate,
180        },
181        exp.DataType.Type.DATETIME: {
182            exp.CurrentDatetime,
183            exp.DatetimeAdd,
184            exp.DatetimeSub,
185        },
186        exp.DataType.Type.DOUBLE: {
187            exp.ApproxQuantile,
188            exp.Avg,
189            exp.Exp,
190            exp.Ln,
191            exp.Log,
192            exp.Log2,
193            exp.Log10,
194            exp.Pow,
195            exp.Quantile,
196            exp.Round,
197            exp.SafeDivide,
198            exp.Sqrt,
199            exp.Stddev,
200            exp.StddevPop,
201            exp.StddevSamp,
202            exp.Variance,
203            exp.VariancePop,
204        },
205        exp.DataType.Type.INT: {
206            exp.Ceil,
207            exp.DateDiff,
208            exp.DatetimeDiff,
209            exp.Extract,
210            exp.TimestampDiff,
211            exp.TimeDiff,
212            exp.DateToDi,
213            exp.Floor,
214            exp.Levenshtein,
215            exp.StrPosition,
216            exp.TsOrDiToDi,
217        },
218        exp.DataType.Type.TIMESTAMP: {
219            exp.CurrentTime,
220            exp.CurrentTimestamp,
221            exp.StrToTime,
222            exp.TimeAdd,
223            exp.TimeStrToTime,
224            exp.TimeSub,
225            exp.Timestamp,
226            exp.TimestampAdd,
227            exp.TimestampSub,
228            exp.UnixToTime,
229        },
230        exp.DataType.Type.TINYINT: {
231            exp.Day,
232            exp.Month,
233            exp.Week,
234            exp.Year,
235        },
236        exp.DataType.Type.VARCHAR: {
237            exp.ArrayConcat,
238            exp.Concat,
239            exp.ConcatWs,
240            exp.DateToDateStr,
241            exp.GroupConcat,
242            exp.Initcap,
243            exp.Lower,
244            exp.SafeConcat,
245            exp.SafeDPipe,
246            exp.Substring,
247            exp.TimeToStr,
248            exp.TimeToTimeStr,
249            exp.Trim,
250            exp.TsOrDsToDateStr,
251            exp.UnixToStr,
252            exp.UnixToTimeStr,
253            exp.Upper,
254        },
255    }
256
257    ANNOTATORS: t.Dict = {
258        **{
259            expr_type: lambda self, e: self._annotate_unary(e)
260            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
261        },
262        **{
263            expr_type: lambda self, e: self._annotate_binary(e)
264            for expr_type in subclasses(exp.__name__, exp.Binary)
265        },
266        **{
267            expr_type: _annotate_with_type_lambda(data_type)
268            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
269            for expr_type in expressions
270        },
271        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
272        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
273        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
274        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
275        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
276        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
277        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
278        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
279        exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
280        exp.DateSub: lambda self, e: self._annotate_dateadd(e),
281        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
282        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
283        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
284        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
285        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
286        exp.Literal: lambda self, e: self._annotate_literal(e),
287        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
288        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
289        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
290        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
291        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
292        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
293        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
294    }
295
296    NESTED_TYPES = {
297        exp.DataType.Type.ARRAY,
298    }
299
300    # Specifies what types a given type can be coerced into (autofilled)
301    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
302
303    # Coercion functions for binary operations.
304    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
305    BINARY_COERCIONS: BinaryCoercions = {
306        **swap_all(
307            {
308                (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
309                for t in exp.DataType.TEXT_TYPES
310            }
311        ),
312        **swap_all(
313            {
314                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
315            }
316        ),
317    }
318
319    def __init__(
320        self,
321        schema: Schema,
322        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
323        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
324        binary_coercions: t.Optional[BinaryCoercions] = None,
325    ) -> None:
326        self.schema = schema
327        self.annotators = annotators or self.ANNOTATORS
328        self.coerces_to = coerces_to or self.COERCES_TO
329        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
330
331        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
332        self._visited: t.Set[int] = set()
333
334    def _set_type(
335        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
336    ) -> None:
337        expression.type = target_type  # type: ignore
338        self._visited.add(id(expression))
339
340    def annotate(self, expression: E) -> E:
341        for scope in traverse_scope(expression):
342            selects = {}
343            for name, source in scope.sources.items():
344                if not isinstance(source, Scope):
345                    continue
346                if isinstance(source.expression, exp.UDTF):
347                    values = []
348
349                    if isinstance(source.expression, exp.Lateral):
350                        if isinstance(source.expression.this, exp.Explode):
351                            values = [source.expression.this.this]
352                    else:
353                        values = source.expression.expressions[0].expressions
354
355                    if not values:
356                        continue
357
358                    selects[name] = {
359                        alias: column
360                        for alias, column in zip(
361                            source.expression.alias_column_names,
362                            values,
363                        )
364                    }
365                else:
366                    selects[name] = {
367                        select.alias_or_name: select for select in source.expression.selects
368                    }
369
370            # First annotate the current scope's column references
371            for col in scope.columns:
372                if not col.table:
373                    continue
374
375                source = scope.sources.get(col.table)
376                if isinstance(source, exp.Table):
377                    self._set_type(col, self.schema.get_column_type(source, col))
378                elif source and col.table in selects and col.name in selects[col.table]:
379                    self._set_type(col, selects[col.table][col.name].type)
380
381            # Then (possibly) annotate the remaining expressions in the scope
382            self._maybe_annotate(scope.expression)
383
384        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
385
386    def _maybe_annotate(self, expression: E) -> E:
387        if id(expression) in self._visited:
388            return expression  # We've already inferred the expression's type
389
390        annotator = self.annotators.get(expression.__class__)
391
392        return (
393            annotator(self, expression)
394            if annotator
395            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
396        )
397
398    def _annotate_args(self, expression: E) -> E:
399        for _, value in expression.iter_expressions():
400            self._maybe_annotate(value)
401
402        return expression
403
404    def _maybe_coerce(
405        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
406    ) -> exp.DataType | exp.DataType.Type:
407        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
408        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
409
410        # We propagate the NULL / UNKNOWN types upwards if found
411        if exp.DataType.Type.NULL in (type1_value, type2_value):
412            return exp.DataType.Type.NULL
413        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
414            return exp.DataType.Type.UNKNOWN
415
416        if type1_value in self.NESTED_TYPES:
417            return type1
418        if type2_value in self.NESTED_TYPES:
419            return type2
420
421        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
422
423    # Note: the following "no_type_check" decorators were added because mypy was yelling due
424    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
425    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
426
427    @t.no_type_check
428    def _annotate_binary(self, expression: B) -> B:
429        self._annotate_args(expression)
430
431        left, right = expression.left, expression.right
432        left_type, right_type = left.type.this, right.type.this
433
434        if isinstance(expression, exp.Connector):
435            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
436                self._set_type(expression, exp.DataType.Type.NULL)
437            elif exp.DataType.Type.NULL in (left_type, right_type):
438                self._set_type(
439                    expression,
440                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
441                )
442            else:
443                self._set_type(expression, exp.DataType.Type.BOOLEAN)
444        elif isinstance(expression, exp.Predicate):
445            self._set_type(expression, exp.DataType.Type.BOOLEAN)
446        elif (left_type, right_type) in self.binary_coercions:
447            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
448        else:
449            self._set_type(expression, self._maybe_coerce(left_type, right_type))
450
451        return expression
452
453    @t.no_type_check
454    def _annotate_unary(self, expression: E) -> E:
455        self._annotate_args(expression)
456
457        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
458            self._set_type(expression, exp.DataType.Type.BOOLEAN)
459        else:
460            self._set_type(expression, expression.this.type)
461
462        return expression
463
464    @t.no_type_check
465    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
466        if expression.is_string:
467            self._set_type(expression, exp.DataType.Type.VARCHAR)
468        elif expression.is_int:
469            self._set_type(expression, exp.DataType.Type.INT)
470        else:
471            self._set_type(expression, exp.DataType.Type.DOUBLE)
472
473        return expression
474
475    @t.no_type_check
476    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
477        self._set_type(expression, target_type)
478        return self._annotate_args(expression)
479
480    @t.no_type_check
481    def _annotate_by_args(
482        self, expression: E, *args: str, promote: bool = False, array: bool = False
483    ) -> E:
484        self._annotate_args(expression)
485
486        expressions: t.List[exp.Expression] = []
487        for arg in args:
488            arg_expr = expression.args.get(arg)
489            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
490
491        last_datatype = None
492        for expr in expressions:
493            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
494
495        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
496
497        if promote:
498            if expression.type.this in exp.DataType.INTEGER_TYPES:
499                self._set_type(expression, exp.DataType.Type.BIGINT)
500            elif expression.type.this in exp.DataType.FLOAT_TYPES:
501                self._set_type(expression, exp.DataType.Type.DOUBLE)
502
503        if array:
504            self._set_type(
505                expression,
506                exp.DataType(
507                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
508                ),
509            )
510
511        return expression
512
513    def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
514        self._annotate_args(expression)
515
516        if expression.this.type.this in exp.DataType.TEXT_TYPES:
517            datatype = _coerce_literal_and_interval(expression.this, expression.interval())
518        elif (
519            expression.this.type.is_type(exp.DataType.Type.DATE)
520            and expression.text("unit").lower() not in DATE_UNITS
521        ):
522            datatype = exp.DataType.Type.DATETIME
523        else:
524            datatype = expression.this.type
525
526        self._set_type(expression, datatype)
527        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None, binary_coercions: Optional[Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]] = None)
319    def __init__(
320        self,
321        schema: Schema,
322        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
323        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
324        binary_coercions: t.Optional[BinaryCoercions] = None,
325    ) -> None:
326        self.schema = schema
327        self.annotators = annotators or self.ANNOTATORS
328        self.coerces_to = coerces_to or self.COERCES_TO
329        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
330
331        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
332        self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] = {<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.Count'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.In'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.DiToDate'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Stddev'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.Levenshtein'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.Timestamp'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.TimestampAdd'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Day'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.SafeDPipe'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.Concat'>}}
ANNOTATORS: Dict = {<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
NESTED_TYPES = {<Type.ARRAY: 'ARRAY'>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] = {<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.CHAR: 'CHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.NCHAR: 'NCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.INT: 'INT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.BIGINT: 'BIGINT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.TINYINT: 'TINYINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>, <Type.SMALLINT: 'SMALLINT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATE: 'DATE'>: {<Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}}
BINARY_COERCIONS: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]] = {(<Type.VARCHAR: 'VARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.NVARCHAR: 'NVARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.CHAR: 'CHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.TEXT: 'TEXT'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.NCHAR: 'NCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.VARCHAR: 'VARCHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NVARCHAR: 'NVARCHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.CHAR: 'CHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.TEXT: 'TEXT'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NCHAR: 'NCHAR'>): <function _coerce_literal_and_interval>, (<Type.DATE: 'DATE'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_date_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.DATE: 'DATE'>): <function _coerce_date_and_interval>}
schema
annotators
coerces_to
binary_coercions
def annotate(self, expression: ~E) -> ~E:
340    def annotate(self, expression: E) -> E:
341        for scope in traverse_scope(expression):
342            selects = {}
343            for name, source in scope.sources.items():
344                if not isinstance(source, Scope):
345                    continue
346                if isinstance(source.expression, exp.UDTF):
347                    values = []
348
349                    if isinstance(source.expression, exp.Lateral):
350                        if isinstance(source.expression.this, exp.Explode):
351                            values = [source.expression.this.this]
352                    else:
353                        values = source.expression.expressions[0].expressions
354
355                    if not values:
356                        continue
357
358                    selects[name] = {
359                        alias: column
360                        for alias, column in zip(
361                            source.expression.alias_column_names,
362                            values,
363                        )
364                    }
365                else:
366                    selects[name] = {
367                        select.alias_or_name: select for select in source.expression.selects
368                    }
369
370            # First annotate the current scope's column references
371            for col in scope.columns:
372                if not col.table:
373                    continue
374
375                source = scope.sources.get(col.table)
376                if isinstance(source, exp.Table):
377                    self._set_type(col, self.schema.get_column_type(source, col))
378                elif source and col.table in selects and col.name in selects[col.table]:
379                    self._set_type(col, selects[col.table][col.name].type)
380
381            # Then (possibly) annotate the remaining expressions in the scope
382            self._maybe_annotate(scope.expression)
383
384        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions