Edit on GitHub

sqlglot.optimizer.annotate_types

  1from __future__ import annotations
  2
  3import functools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot._typing import E
  8from sqlglot.helper import (
  9    ensure_list,
 10    is_date_unit,
 11    is_iso_date,
 12    is_iso_datetime,
 13    seq_get,
 14    subclasses,
 15)
 16from sqlglot.optimizer.scope import Scope, traverse_scope
 17from sqlglot.schema import Schema, ensure_schema
 18
 19if t.TYPE_CHECKING:
 20    B = t.TypeVar("B", bound=exp.Binary)
 21
 22    BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
 23    BinaryCoercions = t.Dict[
 24        t.Tuple[exp.DataType.Type, exp.DataType.Type],
 25        BinaryCoercionFunc,
 26    ]
 27
 28
 29def annotate_types(
 30    expression: E,
 31    schema: t.Optional[t.Dict | Schema] = None,
 32    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
 33    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
 34) -> E:
 35    """
 36    Infers the types of an expression, annotating its AST accordingly.
 37
 38    Example:
 39        >>> import sqlglot
 40        >>> schema = {"y": {"cola": "SMALLINT"}}
 41        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 42        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 43        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 44        <Type.DOUBLE: 'DOUBLE'>
 45
 46    Args:
 47        expression: Expression to annotate.
 48        schema: Database schema.
 49        annotators: Maps expression type to corresponding annotation function.
 50        coerces_to: Maps expression type to set of types that it can be coerced into.
 51
 52    Returns:
 53        The expression annotated with types.
 54    """
 55
 56    schema = ensure_schema(schema)
 57
 58    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 59
 60
 61def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
 62    return lambda self, e: self._annotate_with_type(e, data_type)
 63
 64
 65def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
 66    date_text = l.name
 67    is_iso_date_ = is_iso_date(date_text)
 68
 69    if is_iso_date_ and is_date_unit(unit):
 70        return exp.DataType.Type.DATE
 71
 72    # An ISO date is also an ISO datetime, but not vice versa
 73    if is_iso_date_ or is_iso_datetime(date_text):
 74        return exp.DataType.Type.DATETIME
 75
 76    return exp.DataType.Type.UNKNOWN
 77
 78
 79def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
 80    if not is_date_unit(unit):
 81        return exp.DataType.Type.DATETIME
 82    return l.type.this if l.type else exp.DataType.Type.UNKNOWN
 83
 84
 85def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
 86    @functools.wraps(func)
 87    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
 88        return func(r, l)
 89
 90    return _swapped
 91
 92
 93def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
 94    return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}
 95
 96
 97class _TypeAnnotator(type):
 98    def __new__(cls, clsname, bases, attrs):
 99        klass = super().__new__(cls, clsname, bases, attrs)
100
101        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
102        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
103        text_precedence = (
104            exp.DataType.Type.TEXT,
105            exp.DataType.Type.NVARCHAR,
106            exp.DataType.Type.VARCHAR,
107            exp.DataType.Type.NCHAR,
108            exp.DataType.Type.CHAR,
109        )
110        numeric_precedence = (
111            exp.DataType.Type.DOUBLE,
112            exp.DataType.Type.FLOAT,
113            exp.DataType.Type.DECIMAL,
114            exp.DataType.Type.BIGINT,
115            exp.DataType.Type.INT,
116            exp.DataType.Type.SMALLINT,
117            exp.DataType.Type.TINYINT,
118        )
119        timelike_precedence = (
120            exp.DataType.Type.TIMESTAMPLTZ,
121            exp.DataType.Type.TIMESTAMPTZ,
122            exp.DataType.Type.TIMESTAMP,
123            exp.DataType.Type.DATETIME,
124            exp.DataType.Type.DATE,
125        )
126
127        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
128            coerces_to = set()
129            for data_type in type_precedence:
130                klass.COERCES_TO[data_type] = coerces_to.copy()
131                coerces_to |= {data_type}
132
133        return klass
134
135
class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Infers the types of the nodes of an expression tree and stores them on the nodes.

    The class-level tables below provide the defaults; each one can be replaced per instance
    through the constructor.
    """

    # Expression classes whose instances are always annotated with a fixed data type
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to the function that annotates its instances.
    # Later entries replace earlier ones with the same key, so the explicit entries at the
    # bottom win over the generic ones (e.g. ArrayConcat is typed from its arguments, even
    # though it is also listed under VARCHAR in TYPE_TO_EXPRESSIONS).
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types for which _maybe_coerce returns the full type object instead of just its kind
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    # Coercion functions for binary operations.
    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
    BINARY_COERCIONS: BinaryCoercions = {
        **swap_all(
            {
                # The loop variable is deliberately not called "t", to avoid shadowing the
                # "typing as t" module alias
                (text_type, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
                    l, r.args.get("unit")
                )
                for text_type in exp.DataType.TEXT_TYPES
            }
        ),
        **swap_all(
            {
                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
                    l, r.args.get("unit")
                ),
            }
        ),
    }

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
        binary_coercions: t.Optional[BinaryCoercions] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of table columns.
            annotators: Maps expression type to corresponding annotation function.
            coerces_to: Maps a data type to the set of types that it can be coerced into.
            binary_coercions: Maps a pair of operand types to a function computing the
                resulting type of a binary operation.

        Any mapping that is omitted (or empty) falls back to the class-level default.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO
        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self,
        expression: exp.Expression,
        target_type: t.Optional[exp.DataType | exp.DataType.Type],
    ) -> None:
        """
        Assign `target_type` to `expression` and mark the expression as visited.

        A missing (None) type is stored as UNKNOWN. Callers forward the `.type` of other
        nodes, which may be unset, while the rest of the annotator dereferences
        `expression.type.this` unconditionally.
        """
        expression.type = target_type or exp.DataType.Type.UNKNOWN  # type: ignore
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotate `expression` and all of its sub-expressions with types.

        The scopes yielded by `traverse_scope` are processed in order: table-qualified columns
        are typed first (from the schema or from the projections of the scope they refer to),
        then the rest of each scope's expression is annotated.

        Returns:
            The same expression, annotated.
        """
        for scope in traverse_scope(expression):
            # Maps the name of each Scope source to its projections, keyed by output name
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    # Pair the UDTF's column aliases with the value expressions positionally
                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """
        Annotate `expression` unless it has already been visited.

        Uses the annotator registered for the expression's exact class, or annotates the
        expression as UNKNOWN when there is none.
        """
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotate every child expression of `expression` and return `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Return the type that results from combining `type1` and `type2`.

        NULL takes priority over UNKNOWN, and both take priority over everything else. A nested
        type is returned unchanged (the first operand wins if both are nested). Otherwise the
        result is `type2` if `type1` can be coerced into it according to `self.coerces_to`,
        else `type1`.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotate a binary expression based on the types of its two operands.

        Connectors become NULL, nullable BOOLEAN or BOOLEAN depending on how many operands are
        NULL; predicates become BOOLEAN; otherwise a matching entry of `self.binary_coercions`
        decides the type, with `_maybe_coerce` as the fallback.
        """
        self._annotate_args(expression)

        left, right = expression.left, expression.right
        left_type, right_type = left.type.this, right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif (left_type, right_type) in self.binary_coercions:
            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """
        Annotate a unary expression (or alias).

        Conditions other than parentheses become BOOLEAN; everything else takes the type of
        its operand.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Annotate a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(
        self, expression: E, target_type: exp.DataType | exp.DataType.Type
    ) -> E:
        """Assign `target_type` to `expression`, then annotate its children."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Annotate `expression` with the type obtained by coercing the given arguments together.

        Args:
            expression: The expression to annotate.
            *args: Names of the arguments (keys of `expression.args`) whose types are combined.
            promote: If set, integer results are widened to BIGINT and float results to DOUBLE.
            array: If set, the resulting type is wrapped in an ARRAY type.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression

    def _annotate_timeunit(
        self, expression: exp.TimeUnit | exp.DateTrunc
    ) -> exp.TimeUnit | exp.DateTrunc:
        """
        Annotate an expression that combines an operand with a time unit.

        Text operands go through `_coerce_date_literal`, temporal operands through
        `_coerce_date`; any other operand type yields UNKNOWN.
        """
        self._annotate_args(expression)

        if expression.this.type.this in exp.DataType.TEXT_TYPES:
            datatype = _coerce_date_literal(expression.this, expression.unit)
        elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
            datatype = _coerce_date(expression.this, expression.unit)
        else:
            datatype = exp.DataType.Type.UNKNOWN

        self._set_type(expression, datatype)
        return expression

    def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
        """
        Annotate a bracket (subscript) expression.

        A slice keeps the type of the subscripted value; indexing an ARRAY yields its element
        type; indexing a map expression by one of its keys yields the type of the matching
        value. Anything else is UNKNOWN.
        """
        self._annotate_args(expression)

        bracket_arg = expression.expressions[0]
        this = expression.this

        if isinstance(bracket_arg, exp.Slice):
            self._set_type(expression, this.type)
        elif this.type.is_type(exp.DataType.Type.ARRAY):
            contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
            self._set_type(expression, contained_type)
        elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
            index = this.keys.index(bracket_arg)
            value = seq_get(this.values, index)
            value_type = value.type if value else exp.DataType.Type.UNKNOWN
            self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
        else:
            self._set_type(expression, exp.DataType.Type.UNKNOWN)

        return expression
def annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    binary_coercions: t.Optional[BinaryCoercions] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps a data type to the set of types that it can be coerced into.
        binary_coercions: Maps a pair of (left, right) operand data types to a function that
            receives both operands of a binary operation and returns the resulting type.

    Any of `annotators`, `coerces_to` and `binary_coercions` that is omitted (or empty) falls
    back to the corresponding `TypeAnnotator` class-level default.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    # binary_coercions is forwarded positionally, matching TypeAnnotator.__init__'s parameter order
    return TypeAnnotator(schema, annotators, coerces_to, binary_coercions).annotate(expression)

Infers the types of an expression, annotating its AST accordingly.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression: Expression to annotate.
  • schema: Database schema.
  • annotators: Maps expression type to corresponding annotation function.
  • coerces_to: Maps a data type to the set of types that it can be coerced into.
Returns:

The expression annotated with types.

def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
    """Return a wrapper around `func` that calls it with its two arguments reversed."""

    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
        return func(r, l)

    # Copies func's metadata onto the wrapper, exactly like the @functools.wraps decorator
    return functools.update_wrapper(_swapped, func)
def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
    """
    Return a copy of `coercions` extended with the mirrored version of every entry.

    For each `(a, b) -> func` entry, a `(b, a)` entry is added whose function calls `func`
    with its arguments swapped. A mirrored entry replaces an existing one with the same key.
    """
    symmetric = dict(coercions)

    for (left_type, right_type), func in coercions.items():
        symmetric[(right_type, left_type)] = swap_args(func)

    return symmetric
class TypeAnnotator:
137class TypeAnnotator(metaclass=_TypeAnnotator):
138    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
139        exp.DataType.Type.BIGINT: {
140            exp.ApproxDistinct,
141            exp.ArraySize,
142            exp.Count,
143            exp.Length,
144        },
145        exp.DataType.Type.BOOLEAN: {
146            exp.Between,
147            exp.Boolean,
148            exp.In,
149            exp.RegexpLike,
150        },
151        exp.DataType.Type.DATE: {
152            exp.CurrentDate,
153            exp.Date,
154            exp.DateFromParts,
155            exp.DateStrToDate,
156            exp.DiToDate,
157            exp.StrToDate,
158            exp.TimeStrToDate,
159            exp.TsOrDsToDate,
160        },
161        exp.DataType.Type.DATETIME: {
162            exp.CurrentDatetime,
163            exp.DatetimeAdd,
164            exp.DatetimeSub,
165        },
166        exp.DataType.Type.DOUBLE: {
167            exp.ApproxQuantile,
168            exp.Avg,
169            exp.Exp,
170            exp.Ln,
171            exp.Log,
172            exp.Log2,
173            exp.Log10,
174            exp.Pow,
175            exp.Quantile,
176            exp.Round,
177            exp.SafeDivide,
178            exp.Sqrt,
179            exp.Stddev,
180            exp.StddevPop,
181            exp.StddevSamp,
182            exp.Variance,
183            exp.VariancePop,
184        },
185        exp.DataType.Type.INT: {
186            exp.Ceil,
187            exp.DatetimeDiff,
188            exp.DateDiff,
189            exp.Extract,
190            exp.TimestampDiff,
191            exp.TimeDiff,
192            exp.DateToDi,
193            exp.Floor,
194            exp.Levenshtein,
195            exp.StrPosition,
196            exp.TsOrDiToDi,
197        },
198        exp.DataType.Type.TIMESTAMP: {
199            exp.CurrentTime,
200            exp.CurrentTimestamp,
201            exp.StrToTime,
202            exp.TimeAdd,
203            exp.TimeStrToTime,
204            exp.TimeSub,
205            exp.Timestamp,
206            exp.TimestampAdd,
207            exp.TimestampSub,
208            exp.UnixToTime,
209        },
210        exp.DataType.Type.TINYINT: {
211            exp.Day,
212            exp.Month,
213            exp.Week,
214            exp.Year,
215        },
216        exp.DataType.Type.VARCHAR: {
217            exp.ArrayConcat,
218            exp.Concat,
219            exp.ConcatWs,
220            exp.DateToDateStr,
221            exp.GroupConcat,
222            exp.Initcap,
223            exp.Lower,
224            exp.SafeConcat,
225            exp.SafeDPipe,
226            exp.Substring,
227            exp.TimeToStr,
228            exp.TimeToTimeStr,
229            exp.Trim,
230            exp.TsOrDsToDateStr,
231            exp.UnixToStr,
232            exp.UnixToTimeStr,
233            exp.Upper,
234        },
235    }
236
237    ANNOTATORS: t.Dict = {
238        **{
239            expr_type: lambda self, e: self._annotate_unary(e)
240            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
241        },
242        **{
243            expr_type: lambda self, e: self._annotate_binary(e)
244            for expr_type in subclasses(exp.__name__, exp.Binary)
245        },
246        **{
247            expr_type: _annotate_with_type_lambda(data_type)
248            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
249            for expr_type in expressions
250        },
251        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
252        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
253        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
254        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
255        exp.Bracket: lambda self, e: self._annotate_bracket(e),
256        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
257        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
258        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
259        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
260        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
261        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
262        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
263        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
264        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
265        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
266        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
267        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
268        exp.Literal: lambda self, e: self._annotate_literal(e),
269        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
270        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
271        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
272        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
273        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
274        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
275        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
276        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
277    }
278
279    NESTED_TYPES = {
280        exp.DataType.Type.ARRAY,
281    }
282
283    # Specifies what types a given type can be coerced into (autofilled)
284    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
285
286    # Coercion functions for binary operations.
287    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
288    BINARY_COERCIONS: BinaryCoercions = {
289        **swap_all(
290            {
291                (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
292                    l, r.args.get("unit")
293                )
294                for t in exp.DataType.TEXT_TYPES
295            }
296        ),
297        **swap_all(
298            {
299                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
300                    l, r.args.get("unit")
301                ),
302            }
303        ),
304    }
305
306    def __init__(
307        self,
308        schema: Schema,
309        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
310        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
311        binary_coercions: t.Optional[BinaryCoercions] = None,
312    ) -> None:
313        self.schema = schema
314        self.annotators = annotators or self.ANNOTATORS
315        self.coerces_to = coerces_to or self.COERCES_TO
316        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
317
318        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
319        self._visited: t.Set[int] = set()
320
321    def _set_type(
322        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
323    ) -> None:
324        expression.type = target_type  # type: ignore
325        self._visited.add(id(expression))
326
327    def annotate(self, expression: E) -> E:
328        for scope in traverse_scope(expression):
329            selects = {}
330            for name, source in scope.sources.items():
331                if not isinstance(source, Scope):
332                    continue
333                if isinstance(source.expression, exp.UDTF):
334                    values = []
335
336                    if isinstance(source.expression, exp.Lateral):
337                        if isinstance(source.expression.this, exp.Explode):
338                            values = [source.expression.this.this]
339                    else:
340                        values = source.expression.expressions[0].expressions
341
342                    if not values:
343                        continue
344
345                    selects[name] = {
346                        alias: column
347                        for alias, column in zip(
348                            source.expression.alias_column_names,
349                            values,
350                        )
351                    }
352                else:
353                    selects[name] = {
354                        select.alias_or_name: select for select in source.expression.selects
355                    }
356
357            # First annotate the current scope's column references
358            for col in scope.columns:
359                if not col.table:
360                    continue
361
362                source = scope.sources.get(col.table)
363                if isinstance(source, exp.Table):
364                    self._set_type(col, self.schema.get_column_type(source, col))
365                elif source and col.table in selects and col.name in selects[col.table]:
366                    self._set_type(col, selects[col.table][col.name].type)
367
368            # Then (possibly) annotate the remaining expressions in the scope
369            self._maybe_annotate(scope.expression)
370
371        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
372
373    def _maybe_annotate(self, expression: E) -> E:
374        if id(expression) in self._visited:
375            return expression  # We've already inferred the expression's type
376
377        annotator = self.annotators.get(expression.__class__)
378
379        return (
380            annotator(self, expression)
381            if annotator
382            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
383        )
384
385    def _annotate_args(self, expression: E) -> E:
386        for _, value in expression.iter_expressions():
387            self._maybe_annotate(value)
388
389        return expression
390
391    def _maybe_coerce(
392        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
393    ) -> exp.DataType | exp.DataType.Type:
394        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
395        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
396
397        # We propagate the NULL / UNKNOWN types upwards if found
398        if exp.DataType.Type.NULL in (type1_value, type2_value):
399            return exp.DataType.Type.NULL
400        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
401            return exp.DataType.Type.UNKNOWN
402
403        if type1_value in self.NESTED_TYPES:
404            return type1
405        if type2_value in self.NESTED_TYPES:
406            return type2
407
408        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
409
410    # Note: the following "no_type_check" decorators were added because mypy was yelling due
411    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
412    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
413
414    @t.no_type_check
415    def _annotate_binary(self, expression: B) -> B:
416        self._annotate_args(expression)
417
418        left, right = expression.left, expression.right
419        left_type, right_type = left.type.this, right.type.this
420
421        if isinstance(expression, exp.Connector):
422            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
423                self._set_type(expression, exp.DataType.Type.NULL)
424            elif exp.DataType.Type.NULL in (left_type, right_type):
425                self._set_type(
426                    expression,
427                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
428                )
429            else:
430                self._set_type(expression, exp.DataType.Type.BOOLEAN)
431        elif isinstance(expression, exp.Predicate):
432            self._set_type(expression, exp.DataType.Type.BOOLEAN)
433        elif (left_type, right_type) in self.binary_coercions:
434            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
435        else:
436            self._set_type(expression, self._maybe_coerce(left_type, right_type))
437
438        return expression
439
440    @t.no_type_check
441    def _annotate_unary(self, expression: E) -> E:
442        self._annotate_args(expression)
443
444        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
445            self._set_type(expression, exp.DataType.Type.BOOLEAN)
446        else:
447            self._set_type(expression, expression.this.type)
448
449        return expression
450
451    @t.no_type_check
452    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
453        if expression.is_string:
454            self._set_type(expression, exp.DataType.Type.VARCHAR)
455        elif expression.is_int:
456            self._set_type(expression, exp.DataType.Type.INT)
457        else:
458            self._set_type(expression, exp.DataType.Type.DOUBLE)
459
460        return expression
461
462    @t.no_type_check
463    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
464        self._set_type(expression, target_type)
465        return self._annotate_args(expression)
466
467    @t.no_type_check
468    def _annotate_by_args(
469        self, expression: E, *args: str, promote: bool = False, array: bool = False
470    ) -> E:
471        self._annotate_args(expression)
472
473        expressions: t.List[exp.Expression] = []
474        for arg in args:
475            arg_expr = expression.args.get(arg)
476            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
477
478        last_datatype = None
479        for expr in expressions:
480            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
481
482        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
483
484        if promote:
485            if expression.type.this in exp.DataType.INTEGER_TYPES:
486                self._set_type(expression, exp.DataType.Type.BIGINT)
487            elif expression.type.this in exp.DataType.FLOAT_TYPES:
488                self._set_type(expression, exp.DataType.Type.DOUBLE)
489
490        if array:
491            self._set_type(
492                expression,
493                exp.DataType(
494                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
495                ),
496            )
497
498        return expression
499
500    def _annotate_timeunit(
501        self, expression: exp.TimeUnit | exp.DateTrunc
502    ) -> exp.TimeUnit | exp.DateTrunc:
503        self._annotate_args(expression)
504
505        if expression.this.type.this in exp.DataType.TEXT_TYPES:
506            datatype = _coerce_date_literal(expression.this, expression.unit)
507        elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
508            datatype = _coerce_date(expression.this, expression.unit)
509        else:
510            datatype = exp.DataType.Type.UNKNOWN
511
512        self._set_type(expression, datatype)
513        return expression
514
515    def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
516        self._annotate_args(expression)
517
518        bracket_arg = expression.expressions[0]
519        this = expression.this
520
521        if isinstance(bracket_arg, exp.Slice):
522            self._set_type(expression, this.type)
523        elif this.type.is_type(exp.DataType.Type.ARRAY):
524            contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
525            self._set_type(expression, contained_type)
526        elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
527            index = this.keys.index(bracket_arg)
528            value = seq_get(this.values, index)
529            value_type = value.type if value else exp.DataType.Type.UNKNOWN
530            self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
531        else:
532            self._set_type(expression, exp.DataType.Type.UNKNOWN)
533
534        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None, binary_coercions: Optional[Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]] = None)
306    def __init__(
307        self,
308        schema: Schema,
309        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
310        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
311        binary_coercions: t.Optional[BinaryCoercions] = None,
312    ) -> None:
313        self.schema = schema
314        self.annotators = annotators or self.ANNOTATORS
315        self.coerces_to = coerces_to or self.COERCES_TO
316        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
317
318        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
319        self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] = {<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ArraySize'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.RegexpLike'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.Date'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Log10'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimeDiff'>, 
<class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.DateToDi'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.Timestamp'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimeAdd'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Day'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.SafeDPipe'>, <class 'sqlglot.expressions.Trim'>}}
ANNOTATORS: Dict = {<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, 
<class 'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: 
<function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Bracket'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 
'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
NESTED_TYPES = {<Type.ARRAY: 'ARRAY'>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] = {<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.NCHAR: 'NCHAR'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.INT: 'INT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.INT: 'INT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.SMALLINT: 'SMALLINT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.INT: 'INT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATE: 'DATE'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}}
BINARY_COERCIONS: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]] = {(<Type.NVARCHAR: 'NVARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.TEXT: 'TEXT'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.NCHAR: 'NCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.CHAR: 'CHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.VARCHAR: 'VARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NVARCHAR: 'NVARCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.TEXT: 'TEXT'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NCHAR: 'NCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.CHAR: 'CHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.VARCHAR: 'VARCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.DATE: 'DATE'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.DATE: 'DATE'>): <function TypeAnnotator.<lambda>>}
schema
annotators
coerces_to
binary_coercions
def annotate(self, expression: ~E) -> ~E:
327    def annotate(self, expression: E) -> E:
328        for scope in traverse_scope(expression):
329            selects = {}
330            for name, source in scope.sources.items():
331                if not isinstance(source, Scope):
332                    continue
333                if isinstance(source.expression, exp.UDTF):
334                    values = []
335
336                    if isinstance(source.expression, exp.Lateral):
337                        if isinstance(source.expression.this, exp.Explode):
338                            values = [source.expression.this.this]
339                    else:
340                        values = source.expression.expressions[0].expressions
341
342                    if not values:
343                        continue
344
345                    selects[name] = {
346                        alias: column
347                        for alias, column in zip(
348                            source.expression.alias_column_names,
349                            values,
350                        )
351                    }
352                else:
353                    selects[name] = {
354                        select.alias_or_name: select for select in source.expression.selects
355                    }
356
357            # First annotate the current scope's column references
358            for col in scope.columns:
359                if not col.table:
360                    continue
361
362                source = scope.sources.get(col.table)
363                if isinstance(source, exp.Table):
364                    self._set_type(col, self.schema.get_column_type(source, col))
365                elif source and col.table in selects and col.name in selects[col.table]:
366                    self._set_type(col, selects[col.table][col.name].type)
367
368            # Then (possibly) annotate the remaining expressions in the scope
369            self._maybe_annotate(scope.expression)
370
371        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions