Edit on GitHub

sqlglot.optimizer.annotate_types

  1from __future__ import annotations
  2
  3import functools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot._typing import E
  8from sqlglot.helper import (
  9    ensure_list,
 10    is_date_unit,
 11    is_iso_date,
 12    is_iso_datetime,
 13    seq_get,
 14    subclasses,
 15)
 16from sqlglot.optimizer.scope import Scope, traverse_scope
 17from sqlglot.schema import Schema, ensure_schema
 18
 19if t.TYPE_CHECKING:
 20    B = t.TypeVar("B", bound=exp.Binary)
 21
 22    BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
 23    BinaryCoercions = t.Dict[
 24        t.Tuple[exp.DataType.Type, exp.DataType.Type],
 25        BinaryCoercionFunc,
 26    ]
 27
 28
 29def annotate_types(
 30    expression: E,
 31    schema: t.Optional[t.Dict | Schema] = None,
 32    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
 33    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
 34) -> E:
 35    """
 36    Infers the types of an expression, annotating its AST accordingly.
 37
 38    Example:
 39        >>> import sqlglot
 40        >>> schema = {"y": {"cola": "SMALLINT"}}
 41        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 42        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 43        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 44        <Type.DOUBLE: 'DOUBLE'>
 45
 46    Args:
 47        expression: Expression to annotate.
 48        schema: Database schema.
 49        annotators: Maps expression type to corresponding annotation function.
 50        coerces_to: Maps expression type to set of types that it can be coerced into.
 51
 52    Returns:
 53        The expression annotated with types.
 54    """
 55
 56    schema = ensure_schema(schema)
 57
 58    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 59
 60
 61def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
 62    return lambda self, e: self._annotate_with_type(e, data_type)
 63
 64
 65def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
 66    date_text = l.name
 67    is_iso_date_ = is_iso_date(date_text)
 68
 69    if is_iso_date_ and is_date_unit(unit):
 70        return exp.DataType.Type.DATE
 71
 72    # An ISO date is also an ISO datetime, but not vice versa
 73    if is_iso_date_ or is_iso_datetime(date_text):
 74        return exp.DataType.Type.DATETIME
 75
 76    return exp.DataType.Type.UNKNOWN
 77
 78
 79def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
 80    if not is_date_unit(unit):
 81        return exp.DataType.Type.DATETIME
 82    return l.type.this if l.type else exp.DataType.Type.UNKNOWN
 83
 84
 85def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
 86    @functools.wraps(func)
 87    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
 88        return func(r, l)
 89
 90    return _swapped
 91
 92
 93def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
 94    return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}
 95
 96
 97class _TypeAnnotator(type):
 98    def __new__(cls, clsname, bases, attrs):
 99        klass = super().__new__(cls, clsname, bases, attrs)
100
101        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
102        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
103        text_precedence = (
104            exp.DataType.Type.TEXT,
105            exp.DataType.Type.NVARCHAR,
106            exp.DataType.Type.VARCHAR,
107            exp.DataType.Type.NCHAR,
108            exp.DataType.Type.CHAR,
109        )
110        numeric_precedence = (
111            exp.DataType.Type.DOUBLE,
112            exp.DataType.Type.FLOAT,
113            exp.DataType.Type.DECIMAL,
114            exp.DataType.Type.BIGINT,
115            exp.DataType.Type.INT,
116            exp.DataType.Type.SMALLINT,
117            exp.DataType.Type.TINYINT,
118        )
119        timelike_precedence = (
120            exp.DataType.Type.TIMESTAMPLTZ,
121            exp.DataType.Type.TIMESTAMPTZ,
122            exp.DataType.Type.TIMESTAMP,
123            exp.DataType.Type.DATETIME,
124            exp.DataType.Type.DATE,
125        )
126
127        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
128            coerces_to = set()
129            for data_type in type_precedence:
130                klass.COERCES_TO[data_type] = coerces_to.copy()
131                coerces_to |= {data_type}
132
133        return klass
134
135
class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Infers the type of every node in an expression tree and stores it on the node.

    Each expression class is dispatched through `ANNOTATORS`; classes without an entry are
    annotated as UNKNOWN. The ids of nodes whose type has been set are remembered, so that
    each node is only annotated once per annotator instance.
    """

    # Expression classes whose type does not depend on their arguments. Note that
    # `exp.ArrayConcat` and `exp.Div` also have explicit entries in `ANNOTATORS` below;
    # those later entries are the ones that take effect.
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Div,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to the function that annotates instances of it. The mapping
    # is built in order, so for a class that appears more than once the last entry wins.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that `_maybe_coerce` returns as-is (full DataType included) instead of coercing
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    # Coercion functions for binary operations.
    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
    BINARY_COERCIONS: BinaryCoercions = {
        **swap_all(
            {
                # The loop variable is deliberately not named `t`, which is the `typing` alias
                (text_type, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
                    l, r.args.get("unit")
                )
                for text_type in exp.DataType.TEXT_TYPES
            }
        ),
        **swap_all(
            {
                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
                    l, r.args.get("unit")
                ),
            }
        ),
    }

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
        binary_coercions: t.Optional[BinaryCoercions] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of columns that come from tables.
            annotators: Maps expression type to corresponding annotation function.
                A falsy value (None or empty) selects the class-level `ANNOTATORS`.
            coerces_to: Maps a data type to the set of data types it can be coerced into.
                A falsy value selects the class-level `COERCES_TO`.
            binary_coercions: Maps a pair of operand types to a function computing the type
                of the binary operation. A falsy value selects `BINARY_COERCIONS`.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO
        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self,
        expression: exp.Expression,
        target_type: t.Optional[exp.DataType | exp.DataType.Type],
    ) -> None:
        """
        Stores `target_type` on `expression` and marks the expression as visited.

        A missing (`None`) type is stored as UNKNOWN. Callers forward types read from other
        nodes (e.g. `selects[...].type`, `expression.this.type`), which may not have been
        set; storing `None` would mark the node as visited while leaving it untyped, and
        the `.type.this` lookups done by other annotators would then fail on it.
        """
        if target_type is None:
            target_type = exp.DataType.Type.UNKNOWN

        expression.type = target_type  # type: ignore
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotates `expression` scope by scope and returns it.

        For every scope produced by `traverse_scope`, the table-qualified columns are typed
        first (from the schema for table sources, from the matching projection or value for
        scope sources) and then the scope's expression is annotated.
        """
        for scope in traverse_scope(expression):
            # Maps source name -> {output column name -> expression providing its type}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        # Only laterals wrapping an Explode contribute a value here
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        # Other UDTFs: use the expressions nested in their first expression
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """Annotates `expression` unless its type has already been set by this annotator."""
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotates every expression yielded by `expression.iter_expressions()`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Returns the type that results from combining `type1` and `type2`.

        NULL takes priority over UNKNOWN, and both take priority over everything else. A
        nested type is returned as given. Otherwise the result is `type2` if it is listed in
        `coerces_to` for `type1`, else `type1`.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        # Nested types are returned untouched, so that their inner types are preserved
        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotates a binary expression based on the types of its two operands.

        Connectors are NULL when both sides are NULL, NULLABLE<BOOLEAN> when one side is
        NULL and BOOLEAN otherwise. Other predicates are BOOLEAN. For everything else, a
        matching entry in `binary_coercions` decides, falling back to `_maybe_coerce`.
        """
        self._annotate_args(expression)

        left, right = expression.left, expression.right
        left_type, right_type = left.type.this, right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif (left_type, right_type) in self.binary_coercions:
            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """
        Annotates a single-operand expression.

        Conditions other than `exp.Paren` are BOOLEAN; anything else takes the type of its
        `this` operand.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Annotates a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        """Sets `target_type` on `expression`, then annotates the expression's arguments."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Annotates `expression` with the type obtained by coercing the given arguments.

        Args:
            expression: Expression to annotate.
            *args: Names of the entries in `expression.args` whose types are combined.
            promote: If set, widens an integer result to BIGINT and a float result to DOUBLE.
            array: If set, wraps the resulting type in an ARRAY type.

        Returns:
            The annotated expression. Its type is UNKNOWN if no type could be derived from
            the given arguments.
        """
        self._annotate_args(expression)

        # Flatten the selected args (single expressions or lists), dropping missing ones
        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression

    def _annotate_timeunit(
        self, expression: exp.TimeUnit | exp.DateTrunc
    ) -> exp.TimeUnit | exp.DateTrunc:
        """
        Annotates an expression that combines its `this` operand with a time unit.

        Text operands are resolved with `_coerce_date_literal`, temporal operands with
        `_coerce_date`; any other operand type yields UNKNOWN.
        """
        self._annotate_args(expression)

        if expression.this.type.this in exp.DataType.TEXT_TYPES:
            datatype = _coerce_date_literal(expression.this, expression.unit)
        elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
            datatype = _coerce_date(expression.this, expression.unit)
        else:
            datatype = exp.DataType.Type.UNKNOWN

        self._set_type(expression, datatype)
        return expression

    def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
        """
        Annotates a bracket expression based on what is being subscripted.

        A slice keeps the type of the subscripted expression, an array yields its first
        contained type, and a Map / VarMap literal indexed by one of its own keys yields the
        type of the corresponding value. Everything else is UNKNOWN.
        """
        self._annotate_args(expression)

        bracket_arg = expression.expressions[0]
        this = expression.this

        if isinstance(bracket_arg, exp.Slice):
            self._set_type(expression, this.type)
        elif this.type.is_type(exp.DataType.Type.ARRAY):
            contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
            self._set_type(expression, contained_type)
        elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
            index = this.keys.index(bracket_arg)
            value = seq_get(this.values, index)
            value_type = value.type if value else exp.DataType.Type.UNKNOWN
            self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
        else:
            self._set_type(expression, exp.DataType.Type.UNKNOWN)

        return expression

    def _annotate_div(self, expression: exp.Div) -> exp.Div:
        """
        Annotates a division.

        The result is BIGINT when the expression's "typed" arg is set and both operands are
        integers; otherwise the operand types are combined with `_maybe_coerce`.
        """
        self._annotate_args(expression)

        left_type, right_type = expression.left.type.this, expression.right.type.this  # type: ignore

        if (
            expression.args.get("typed")
            and left_type in exp.DataType.INTEGER_TYPES
            and right_type in exp.DataType.INTEGER_TYPES
        ):
            self._set_type(expression, exp.DataType.Type.BIGINT)
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression
def annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Annotates the AST of an expression with the types inferred for its nodes.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema; it is normalized with `ensure_schema` before being used.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps a data type to the set of data types that it can be coerced into.

    Returns:
        The expression annotated with types.
    """

    type_annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)

    return type_annotator.annotate(expression)

Infers the types of an expression, annotating its AST accordingly.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression: Expression to annotate.
  • schema: Database schema.
  • annotators: Maps expression type to corresponding annotation function.
  • coerces_to: Maps a data type to the set of data types that it can be coerced into.
Returns:

The expression annotated with types.

def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
    """Returns a wrapper around `func` that calls it with its two arguments in reverse order."""

    def _reversed_call(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
        # The wrapper's (l, r) are handed to the wrapped function as (r, l).
        return func(r, l)

    # Same effect as decorating with `functools.wraps(func)`: func's metadata is copied over.
    return functools.update_wrapper(_reversed_call, func)
def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
    """
    Returns a copy of `coercions` that also holds, for every `(a, b)` key, a `(b, a)` key
    mapped to the argument-swapped version of the same function.
    """
    symmetric = dict(coercions)

    # Iterate over the untouched input so that entries added here are never re-swapped.
    for (first, second), coercion in coercions.items():
        symmetric[(second, first)] = swap_args(coercion)

    return symmetric
class TypeAnnotator:
137class TypeAnnotator(metaclass=_TypeAnnotator):
138    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
139        exp.DataType.Type.BIGINT: {
140            exp.ApproxDistinct,
141            exp.ArraySize,
142            exp.Count,
143            exp.Length,
144        },
145        exp.DataType.Type.BOOLEAN: {
146            exp.Between,
147            exp.Boolean,
148            exp.In,
149            exp.RegexpLike,
150        },
151        exp.DataType.Type.DATE: {
152            exp.CurrentDate,
153            exp.Date,
154            exp.DateFromParts,
155            exp.DateStrToDate,
156            exp.DiToDate,
157            exp.StrToDate,
158            exp.TimeStrToDate,
159            exp.TsOrDsToDate,
160        },
161        exp.DataType.Type.DATETIME: {
162            exp.CurrentDatetime,
163            exp.DatetimeAdd,
164            exp.DatetimeSub,
165        },
166        exp.DataType.Type.DOUBLE: {
167            exp.ApproxQuantile,
168            exp.Avg,
169            exp.Div,
170            exp.Exp,
171            exp.Ln,
172            exp.Log,
173            exp.Log2,
174            exp.Log10,
175            exp.Pow,
176            exp.Quantile,
177            exp.Round,
178            exp.SafeDivide,
179            exp.Sqrt,
180            exp.Stddev,
181            exp.StddevPop,
182            exp.StddevSamp,
183            exp.Variance,
184            exp.VariancePop,
185        },
186        exp.DataType.Type.INT: {
187            exp.Ceil,
188            exp.DatetimeDiff,
189            exp.DateDiff,
190            exp.Extract,
191            exp.TimestampDiff,
192            exp.TimeDiff,
193            exp.DateToDi,
194            exp.Floor,
195            exp.Levenshtein,
196            exp.StrPosition,
197            exp.TsOrDiToDi,
198        },
199        exp.DataType.Type.TIMESTAMP: {
200            exp.CurrentTime,
201            exp.CurrentTimestamp,
202            exp.StrToTime,
203            exp.TimeAdd,
204            exp.TimeStrToTime,
205            exp.TimeSub,
206            exp.Timestamp,
207            exp.TimestampAdd,
208            exp.TimestampSub,
209            exp.UnixToTime,
210        },
211        exp.DataType.Type.TINYINT: {
212            exp.Day,
213            exp.Month,
214            exp.Week,
215            exp.Year,
216        },
217        exp.DataType.Type.VARCHAR: {
218            exp.ArrayConcat,
219            exp.Concat,
220            exp.ConcatWs,
221            exp.DateToDateStr,
222            exp.GroupConcat,
223            exp.Initcap,
224            exp.Lower,
225            exp.SafeConcat,
226            exp.SafeDPipe,
227            exp.Substring,
228            exp.TimeToStr,
229            exp.TimeToTimeStr,
230            exp.Trim,
231            exp.TsOrDsToDateStr,
232            exp.UnixToStr,
233            exp.UnixToTimeStr,
234            exp.Upper,
235        },
236    }
237
238    ANNOTATORS: t.Dict = {
239        **{
240            expr_type: lambda self, e: self._annotate_unary(e)
241            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
242        },
243        **{
244            expr_type: lambda self, e: self._annotate_binary(e)
245            for expr_type in subclasses(exp.__name__, exp.Binary)
246        },
247        **{
248            expr_type: _annotate_with_type_lambda(data_type)
249            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
250            for expr_type in expressions
251        },
252        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
253        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
254        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
255        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
256        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
257        exp.Bracket: lambda self, e: self._annotate_bracket(e),
258        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
259        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
260        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
261        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
262        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
263        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
264        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
265        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
266        exp.Div: lambda self, e: self._annotate_div(e),
267        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
268        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
269        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
270        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
271        exp.Literal: lambda self, e: self._annotate_literal(e),
272        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
273        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
274        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
275        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
276        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
277        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
278        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
279        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
280        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
281    }
282
283    NESTED_TYPES = {
284        exp.DataType.Type.ARRAY,
285    }
286
287    # Specifies what types a given type can be coerced into (autofilled)
288    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
289
290    # Coercion functions for binary operations.
291    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
292    BINARY_COERCIONS: BinaryCoercions = {
293        **swap_all(
294            {
295                (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
296                    l, r.args.get("unit")
297                )
298                for t in exp.DataType.TEXT_TYPES
299            }
300        ),
301        **swap_all(
302            {
303                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
304                    l, r.args.get("unit")
305                ),
306            }
307        ),
308    }
309
310    def __init__(
311        self,
312        schema: Schema,
313        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
314        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
315        binary_coercions: t.Optional[BinaryCoercions] = None,
316    ) -> None:
317        self.schema = schema
318        self.annotators = annotators or self.ANNOTATORS
319        self.coerces_to = coerces_to or self.COERCES_TO
320        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
321
322        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
323        self._visited: t.Set[int] = set()
324
325    def _set_type(
326        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
327    ) -> None:
328        expression.type = target_type  # type: ignore
329        self._visited.add(id(expression))
330
331    def annotate(self, expression: E) -> E:
332        for scope in traverse_scope(expression):
333            selects = {}
334            for name, source in scope.sources.items():
335                if not isinstance(source, Scope):
336                    continue
337                if isinstance(source.expression, exp.UDTF):
338                    values = []
339
340                    if isinstance(source.expression, exp.Lateral):
341                        if isinstance(source.expression.this, exp.Explode):
342                            values = [source.expression.this.this]
343                    else:
344                        values = source.expression.expressions[0].expressions
345
346                    if not values:
347                        continue
348
349                    selects[name] = {
350                        alias: column
351                        for alias, column in zip(
352                            source.expression.alias_column_names,
353                            values,
354                        )
355                    }
356                else:
357                    selects[name] = {
358                        select.alias_or_name: select for select in source.expression.selects
359                    }
360
361            # First annotate the current scope's column references
362            for col in scope.columns:
363                if not col.table:
364                    continue
365
366                source = scope.sources.get(col.table)
367                if isinstance(source, exp.Table):
368                    self._set_type(col, self.schema.get_column_type(source, col))
369                elif source and col.table in selects and col.name in selects[col.table]:
370                    self._set_type(col, selects[col.table][col.name].type)
371
372            # Then (possibly) annotate the remaining expressions in the scope
373            self._maybe_annotate(scope.expression)
374
375        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
376
377    def _maybe_annotate(self, expression: E) -> E:
378        if id(expression) in self._visited:
379            return expression  # We've already inferred the expression's type
380
381        annotator = self.annotators.get(expression.__class__)
382
383        return (
384            annotator(self, expression)
385            if annotator
386            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
387        )
388
389    def _annotate_args(self, expression: E) -> E:
390        for _, value in expression.iter_expressions():
391            self._maybe_annotate(value)
392
393        return expression
394
395    def _maybe_coerce(
396        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
397    ) -> exp.DataType | exp.DataType.Type:
398        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
399        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
400
401        # We propagate the NULL / UNKNOWN types upwards if found
402        if exp.DataType.Type.NULL in (type1_value, type2_value):
403            return exp.DataType.Type.NULL
404        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
405            return exp.DataType.Type.UNKNOWN
406
407        if type1_value in self.NESTED_TYPES:
408            return type1
409        if type2_value in self.NESTED_TYPES:
410            return type2
411
412        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
413
414    # Note: the following "no_type_check" decorators were added because mypy was yelling due
415    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
416    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
417
418    @t.no_type_check
419    def _annotate_binary(self, expression: B) -> B:
420        self._annotate_args(expression)
421
422        left, right = expression.left, expression.right
423        left_type, right_type = left.type.this, right.type.this
424
425        if isinstance(expression, exp.Connector):
426            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
427                self._set_type(expression, exp.DataType.Type.NULL)
428            elif exp.DataType.Type.NULL in (left_type, right_type):
429                self._set_type(
430                    expression,
431                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
432                )
433            else:
434                self._set_type(expression, exp.DataType.Type.BOOLEAN)
435        elif isinstance(expression, exp.Predicate):
436            self._set_type(expression, exp.DataType.Type.BOOLEAN)
437        elif (left_type, right_type) in self.binary_coercions:
438            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
439        else:
440            self._set_type(expression, self._maybe_coerce(left_type, right_type))
441
442        return expression
443
444    @t.no_type_check
445    def _annotate_unary(self, expression: E) -> E:
446        self._annotate_args(expression)
447
448        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
449            self._set_type(expression, exp.DataType.Type.BOOLEAN)
450        else:
451            self._set_type(expression, expression.this.type)
452
453        return expression
454
455    @t.no_type_check
456    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
457        if expression.is_string:
458            self._set_type(expression, exp.DataType.Type.VARCHAR)
459        elif expression.is_int:
460            self._set_type(expression, exp.DataType.Type.INT)
461        else:
462            self._set_type(expression, exp.DataType.Type.DOUBLE)
463
464        return expression
465
466    @t.no_type_check
467    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
468        self._set_type(expression, target_type)
469        return self._annotate_args(expression)
470
471    @t.no_type_check
472    def _annotate_by_args(
473        self, expression: E, *args: str, promote: bool = False, array: bool = False
474    ) -> E:
475        self._annotate_args(expression)
476
477        expressions: t.List[exp.Expression] = []
478        for arg in args:
479            arg_expr = expression.args.get(arg)
480            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
481
482        last_datatype = None
483        for expr in expressions:
484            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
485
486        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
487
488        if promote:
489            if expression.type.this in exp.DataType.INTEGER_TYPES:
490                self._set_type(expression, exp.DataType.Type.BIGINT)
491            elif expression.type.this in exp.DataType.FLOAT_TYPES:
492                self._set_type(expression, exp.DataType.Type.DOUBLE)
493
494        if array:
495            self._set_type(
496                expression,
497                exp.DataType(
498                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
499                ),
500            )
501
502        return expression
503
504    def _annotate_timeunit(
505        self, expression: exp.TimeUnit | exp.DateTrunc
506    ) -> exp.TimeUnit | exp.DateTrunc:
507        self._annotate_args(expression)
508
509        if expression.this.type.this in exp.DataType.TEXT_TYPES:
510            datatype = _coerce_date_literal(expression.this, expression.unit)
511        elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
512            datatype = _coerce_date(expression.this, expression.unit)
513        else:
514            datatype = exp.DataType.Type.UNKNOWN
515
516        self._set_type(expression, datatype)
517        return expression
518
519    def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
520        self._annotate_args(expression)
521
522        bracket_arg = expression.expressions[0]
523        this = expression.this
524
525        if isinstance(bracket_arg, exp.Slice):
526            self._set_type(expression, this.type)
527        elif this.type.is_type(exp.DataType.Type.ARRAY):
528            contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
529            self._set_type(expression, contained_type)
530        elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
531            index = this.keys.index(bracket_arg)
532            value = seq_get(this.values, index)
533            value_type = value.type if value else exp.DataType.Type.UNKNOWN
534            self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
535        else:
536            self._set_type(expression, exp.DataType.Type.UNKNOWN)
537
538        return expression
539
540    def _annotate_div(self, expression: exp.Div) -> exp.Div:
541        self._annotate_args(expression)
542
543        left_type, right_type = expression.left.type.this, expression.right.type.this  # type: ignore
544
545        if (
546            expression.args.get("typed")
547            and left_type in exp.DataType.INTEGER_TYPES
548            and right_type in exp.DataType.INTEGER_TYPES
549        ):
550            self._set_type(expression, exp.DataType.Type.BIGINT)
551        else:
552            self._set_type(expression, self._maybe_coerce(left_type, right_type))
553
554        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None, binary_coercions: Optional[Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]] = None)
310    def __init__(
311        self,
312        schema: Schema,
313        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
314        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
315        binary_coercions: t.Optional[BinaryCoercions] = None,
316    ) -> None:
317        self.schema = schema
318        self.annotators = annotators or self.ANNOTATORS
319        self.coerces_to = coerces_to or self.COERCES_TO
320        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
321
322        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
323        self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] = {<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.ApproxDistinct'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.Boolean'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.DiToDate'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Div'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.SafeDivide'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 
'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.Levenshtein'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.Timestamp'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Week'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.SafeDPipe'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.UnixToStr'>}}
ANNOTATORS: Dict = {<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 
'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.PropertyEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.RegexpILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, 
<class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: 
<function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, 
<class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Abs'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Bracket'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 
'sqlglot.expressions.Nullif'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
NESTED_TYPES = {<Type.ARRAY: 'ARRAY'>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] = {<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.NCHAR: 'NCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>, <Type.NCHAR: 'NCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.INT: 'INT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.DECIMAL: 'DECIMAL'>, <Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>, <Type.INT: 'INT'>, <Type.BIGINT: 'BIGINT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.SMALLINT: 'SMALLINT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>, <Type.INT: 'INT'>, <Type.BIGINT: 'BIGINT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATE: 'DATE'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}}
BINARY_COERCIONS: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]] = {(<Type.TEXT: 'TEXT'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.NCHAR: 'NCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.CHAR: 'CHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.NVARCHAR: 'NVARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.VARCHAR: 'VARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.TEXT: 'TEXT'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NCHAR: 'NCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.CHAR: 'CHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NVARCHAR: 'NVARCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.VARCHAR: 'VARCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.DATE: 'DATE'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.DATE: 'DATE'>): <function TypeAnnotator.<lambda>>}
schema
annotators
coerces_to
binary_coercions
def annotate(self, expression: ~E) -> ~E:
331    def annotate(self, expression: E) -> E:
332        for scope in traverse_scope(expression):
333            selects = {}
334            for name, source in scope.sources.items():
335                if not isinstance(source, Scope):
336                    continue
337                if isinstance(source.expression, exp.UDTF):
338                    values = []
339
340                    if isinstance(source.expression, exp.Lateral):
341                        if isinstance(source.expression.this, exp.Explode):
342                            values = [source.expression.this.this]
343                    else:
344                        values = source.expression.expressions[0].expressions
345
346                    if not values:
347                        continue
348
349                    selects[name] = {
350                        alias: column
351                        for alias, column in zip(
352                            source.expression.alias_column_names,
353                            values,
354                        )
355                    }
356                else:
357                    selects[name] = {
358                        select.alias_or_name: select for select in source.expression.selects
359                    }
360
361            # First annotate the current scope's column references
362            for col in scope.columns:
363                if not col.table:
364                    continue
365
366                source = scope.sources.get(col.table)
367                if isinstance(source, exp.Table):
368                    self._set_type(col, self.schema.get_column_type(source, col))
369                elif source and col.table in selects and col.name in selects[col.table]:
370                    self._set_type(col, selects[col.table][col.name].type)
371
372            # Then (possibly) annotate the remaining expressions in the scope
373            self._maybe_annotate(scope.expression)
374
375        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions