Edit on GitHub

sqlglot.optimizer.annotate_types

  1from __future__ import annotations
  2
  3import datetime
  4import functools
  5import typing as t
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.helper import ensure_list, seq_get, subclasses
 10from sqlglot.optimizer.scope import Scope, traverse_scope
 11from sqlglot.schema import Schema, ensure_schema
 12
 13if t.TYPE_CHECKING:
 14    B = t.TypeVar("B", bound=exp.Binary)
 15
 16    BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
 17    BinaryCoercions = t.Dict[
 18        t.Tuple[exp.DataType.Type, exp.DataType.Type],
 19        BinaryCoercionFunc,
 20    ]
 21
 22
 23# Interval units that operate on date components
 24DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
 25
 26
 27def annotate_types(
 28    expression: E,
 29    schema: t.Optional[t.Dict | Schema] = None,
 30    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
 31    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
 32) -> E:
 33    """
 34    Infers the types of an expression, annotating its AST accordingly.
 35
 36    Example:
 37        >>> import sqlglot
 38        >>> schema = {"y": {"cola": "SMALLINT"}}
 39        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 40        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 41        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 42        <Type.DOUBLE: 'DOUBLE'>
 43
 44    Args:
 45        expression: Expression to annotate.
 46        schema: Database schema.
 47        annotators: Maps expression type to corresponding annotation function.
 48        coerces_to: Maps expression type to set of types that it can be coerced into.
 49
 50    Returns:
 51        The expression annotated with types.
 52    """
 53
 54    schema = ensure_schema(schema)
 55
 56    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 57
 58
 59def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
 60    return lambda self, e: self._annotate_with_type(e, data_type)
 61
 62
 63def _is_iso_date(text: str) -> bool:
 64    try:
 65        datetime.date.fromisoformat(text)
 66        return True
 67    except ValueError:
 68        return False
 69
 70
 71def _is_iso_datetime(text: str) -> bool:
 72    try:
 73        datetime.datetime.fromisoformat(text)
 74        return True
 75    except ValueError:
 76        return False
 77
 78
 79def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
 80    date_text = l.name
 81    unit = r.text("unit").lower()
 82
 83    is_iso_date = _is_iso_date(date_text)
 84
 85    if is_iso_date and unit in DATE_UNITS:
 86        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
 87        return exp.DataType.Type.DATE
 88
 89    # An ISO date is also an ISO datetime, but not vice versa
 90    if is_iso_date or _is_iso_datetime(date_text):
 91        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
 92        return exp.DataType.Type.DATETIME
 93
 94    return exp.DataType.Type.UNKNOWN
 95
 96
 97def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
 98    unit = r.text("unit").lower()
 99    if unit not in DATE_UNITS:
100        return exp.DataType.Type.DATETIME
101    return l.type.this if l.type else exp.DataType.Type.UNKNOWN
102
103
def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
    """Wrap a binary coercion callable so that it receives its two operands in reverse order."""

    @functools.wraps(func)
    def _swapped(left: exp.Expression, right: exp.Expression) -> exp.DataType.Type:
        return func(right, left)

    return _swapped


def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
    """
    Return a copy of `coercions` that also registers every (a, b) entry under (b, a).

    The mirrored entries wrap the original callable with `swap_args`, so each callable still
    sees its operands in the order it was written for. A mirrored entry overwrites an
    existing entry with the same key.
    """
    mirrored = dict(coercions)
    for (first_type, second_type), func in coercions.items():
        mirrored[(second_type, first_type)] = swap_args(func)
    return mirrored
114
115
class _TypeAnnotator(type):
    """
    Metaclass that populates `COERCES_TO` on every class it creates.

    Each precedence chain below lists related types from highest to lowest precedence. A type
    is mapped to the set of all types that come before it in its chain, i.e. the types it can
    be coerced into. `COERCES_TO` is looked up as an ordinary class attribute, so a subclass
    that does not define its own dict writes into the inherited one.
    """

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
        precedence_chains = (
            (
                exp.DataType.Type.TEXT,
                exp.DataType.Type.NVARCHAR,
                exp.DataType.Type.VARCHAR,
                exp.DataType.Type.NCHAR,
                exp.DataType.Type.CHAR,
            ),
            (
                exp.DataType.Type.DOUBLE,
                exp.DataType.Type.FLOAT,
                exp.DataType.Type.DECIMAL,
                exp.DataType.Type.BIGINT,
                exp.DataType.Type.INT,
                exp.DataType.Type.SMALLINT,
                exp.DataType.Type.TINYINT,
            ),
            (
                exp.DataType.Type.TIMESTAMPLTZ,
                exp.DataType.Type.TIMESTAMPTZ,
                exp.DataType.Type.TIMESTAMP,
                exp.DataType.Type.DATETIME,
                exp.DataType.Type.DATE,
            ),
        )

        for chain in precedence_chains:
            for position, data_type in enumerate(chain):
                # Everything ranked above `data_type` in this chain is a valid coercion target.
                klass.COERCES_TO[data_type] = set(chain[:position])

        return klass
153
154
class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Infers the data types of the nodes of an expression tree and annotates them in place.

    Behaviour is driven by class-level tables, which subclasses can override:

    - TYPE_TO_EXPRESSIONS: expression classes whose result type is fixed, grouped by that type.
    - ANNOTATORS: maps an expression class to the callable that annotates its instances.
    - NESTED_TYPES: types that `_maybe_coerce` propagates unchanged.
    - COERCES_TO: maps a type to the set of types it can be coerced into (filled in by the
      `_TypeAnnotator` metaclass).
    - BINARY_COERCIONS: maps a (left type, right type) pair to a callable that computes the
      result type of a binary expression whose operands have those types.
    """

    # Expression classes whose result type does not depend on their arguments.
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Dict entries are applied in order: the explicit per-class entries at the end override the
    # generic Unary/Alias, Binary and fixed-type entries produced by the comprehensions above them.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
        exp.DateSub: lambda self, e: self._annotate_dateadd(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that `_maybe_coerce` returns unchanged (keeping their element types) whenever
    # either operand has one of them.
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    # Coercion functions for binary operations.
    # Map of type pairs to a callable that takes both sides of the binary operation and returns
    # the resulting type. `swap_all` registers every pair in both operand orders.
    BINARY_COERCIONS: BinaryCoercions = {
        **swap_all(
            {
                (text_type, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
                for text_type in exp.DataType.TEXT_TYPES
            }
        ),
        **swap_all(
            {
                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
            }
        ),
    }

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
        binary_coercions: t.Optional[BinaryCoercions] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of table columns.
            annotators: Replaces `ANNOTATORS` when given; an empty mapping falls back to it.
            coerces_to: Replaces `COERCES_TO` when given; an empty mapping falls back to it.
            binary_coercions: Replaces `BINARY_COERCIONS` when given; an empty mapping falls
                back to it.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO
        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
    ) -> None:
        """Assign `target_type` to `expression` and mark it visited so it is not re-annotated."""
        expression.type = target_type  # type: ignore
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotate `expression` and its sub-expressions with types, in place.

        For every scope produced by `traverse_scope`, qualified column references are typed
        first. Their types come either from the schema (table sources) or from the matching
        projection of a child scope or UDTF source. The scope's own expression is annotated
        afterwards.
        """
        for scope in traverse_scope(expression):
            # Maps a scope-source name to {output column name: expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        # A UDTF without arguments yields no values instead of an IndexError
                        first_arg = seq_get(source.expression.expressions, 0)
                        values = first_arg.expressions if first_arg is not None else []

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """
        Annotate `expression` with its registered annotator, unless it was already visited.

        Expression classes without a registered annotator are typed UNKNOWN; their children are
        still annotated.
        """
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotate each sub-expression yielded by `expression.iter_expressions()`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Pick the type resulting from combining values of `type1` and `type2`.

        Precedence: NULL, then UNKNOWN, then a nested type (returned as given, so its element
        types are kept). Otherwise `type2` wins if `type1` can be coerced into it, else `type1`.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, set()) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Type a binary expression from its annotated operands.

        Connectors are BOOLEAN (NULL if both sides are NULL, a NULLABLE boolean if only one is).
        Predicates are BOOLEAN. Otherwise a registered binary coercion is used if one matches
        the operand types, falling back to `_maybe_coerce`.
        """
        self._annotate_args(expression)

        left, right = expression.left, expression.right
        left_type, right_type = left.type.this, right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif (left_type, right_type) in self.binary_coercions:
            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """Conditions (except parentheses) are BOOLEAN; anything else takes its operand's type."""
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """String literals are VARCHAR, integer literals INT, and any other literal DOUBLE."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        """Assign `target_type` to `expression`, then annotate its children."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Type `expression` by folding the types of the named args through `_maybe_coerce`.

        Args:
            expression: The expression to annotate.
            *args: Names of the args whose (possibly list-valued) expressions are coerced.
            promote: Widen integer results to BIGINT and float results to DOUBLE.
            array: Wrap the resulting type in an ARRAY.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression

    def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
        """
        Type a DATE_ADD / DATE_SUB expression.

        Text operands go through `_coerce_literal_and_interval`; a DATE shifted by a unit not
        in DATE_UNITS becomes DATETIME; anything else keeps the type of `this`.
        """
        self._annotate_args(expression)

        if expression.this.type.this in exp.DataType.TEXT_TYPES:
            datatype = _coerce_literal_and_interval(expression.this, expression.interval())
        elif (
            expression.this.type.is_type(exp.DataType.Type.DATE)
            and expression.text("unit").lower() not in DATE_UNITS
        ):
            datatype = exp.DataType.Type.DATETIME
        else:
            datatype = expression.this.type

        self._set_type(expression, datatype)
        return expression

    def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
        """
        Type a subscript expression.

        Slices keep the container's type. Indexing an ARRAY yields its element type (UNKNOWN
        if the element type is not known). Looking up a key that appears literally in a MAP /
        VARMAP constructor yields the type of the matching value. Anything else is UNKNOWN.
        """
        self._annotate_args(expression)

        bracket_arg = expression.expressions[0]
        this = expression.this

        if isinstance(bracket_arg, exp.Slice):
            self._set_type(expression, this.type)
        elif this.type.is_type(exp.DataType.Type.ARRAY):
            contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
            self._set_type(expression, contained_type)
        elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
            index = this.keys.index(bracket_arg)
            value = seq_get(this.values, index)
            value_type = value.type if value else exp.DataType.Type.UNKNOWN
            self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
        else:
            self._set_type(expression, exp.DataType.Type.UNKNOWN)

        return expression
DATE_UNITS = {'year', 'day', 'month', 'year_month', 'week', 'quarter'}
def annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    binary_coercions: t.Optional[BinaryCoercions] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    The AST is annotated in place (and some text operands may be replaced by CASTs while
    coercing them against intervals); the same root expression is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema; normalized through `ensure_schema`.
        annotators: Maps expression type to corresponding annotation function.
            When omitted (or empty), `TypeAnnotator.ANNOTATORS` is used.
        coerces_to: Maps expression type to set of types that it can be coerced into.
            When omitted (or empty), `TypeAnnotator.COERCES_TO` is used.
        binary_coercions: Maps (left type, right type) pairs to a callable that computes the
            result type of a binary expression with operands of those types.
            When omitted (or empty), `TypeAnnotator.BINARY_COERCIONS` is used.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(
        schema,
        annotators=annotators,
        coerces_to=coerces_to,
        binary_coercions=binary_coercions,
    ).annotate(expression)

Infers the types of an expression, annotating its AST accordingly.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression: Expression to annotate.
  • schema: Database schema.
  • annotators: Maps expression type to corresponding annotation function.
  • coerces_to: Maps expression type to set of types that it can be coerced into.
Returns:

The expression annotated with types.

def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
    """Wrap a binary coercion callable so that it receives its two operands in reverse order."""

    @functools.wraps(func)
    def _swapped(left: exp.Expression, right: exp.Expression) -> exp.DataType.Type:
        return func(right, left)

    return _swapped


def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
    """
    Return a copy of `coercions` that also registers every (a, b) entry under (b, a).

    The mirrored entries wrap the original callable with `swap_args`, so each callable still
    sees its operands in the order it was written for. A mirrored entry overwrites an
    existing entry with the same key.
    """
    mirrored = dict(coercions)
    for (first_type, second_type), func in coercions.items():
        mirrored[(second_type, first_type)] = swap_args(func)
    return mirrored
class TypeAnnotator:
156class TypeAnnotator(metaclass=_TypeAnnotator):
157    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
158        exp.DataType.Type.BIGINT: {
159            exp.ApproxDistinct,
160            exp.ArraySize,
161            exp.Count,
162            exp.Length,
163        },
164        exp.DataType.Type.BOOLEAN: {
165            exp.Between,
166            exp.Boolean,
167            exp.In,
168            exp.RegexpLike,
169        },
170        exp.DataType.Type.DATE: {
171            exp.CurrentDate,
172            exp.Date,
173            exp.DateFromParts,
174            exp.DateStrToDate,
175            exp.DateTrunc,
176            exp.DiToDate,
177            exp.StrToDate,
178            exp.TimeStrToDate,
179            exp.TsOrDsToDate,
180        },
181        exp.DataType.Type.DATETIME: {
182            exp.CurrentDatetime,
183            exp.DatetimeAdd,
184            exp.DatetimeSub,
185        },
186        exp.DataType.Type.DOUBLE: {
187            exp.ApproxQuantile,
188            exp.Avg,
189            exp.Exp,
190            exp.Ln,
191            exp.Log,
192            exp.Log2,
193            exp.Log10,
194            exp.Pow,
195            exp.Quantile,
196            exp.Round,
197            exp.SafeDivide,
198            exp.Sqrt,
199            exp.Stddev,
200            exp.StddevPop,
201            exp.StddevSamp,
202            exp.Variance,
203            exp.VariancePop,
204        },
205        exp.DataType.Type.INT: {
206            exp.Ceil,
207            exp.DateDiff,
208            exp.DatetimeDiff,
209            exp.Extract,
210            exp.TimestampDiff,
211            exp.TimeDiff,
212            exp.DateToDi,
213            exp.Floor,
214            exp.Levenshtein,
215            exp.StrPosition,
216            exp.TsOrDiToDi,
217        },
218        exp.DataType.Type.TIMESTAMP: {
219            exp.CurrentTime,
220            exp.CurrentTimestamp,
221            exp.StrToTime,
222            exp.TimeAdd,
223            exp.TimeStrToTime,
224            exp.TimeSub,
225            exp.Timestamp,
226            exp.TimestampAdd,
227            exp.TimestampSub,
228            exp.UnixToTime,
229        },
230        exp.DataType.Type.TINYINT: {
231            exp.Day,
232            exp.Month,
233            exp.Week,
234            exp.Year,
235        },
236        exp.DataType.Type.VARCHAR: {
237            exp.ArrayConcat,
238            exp.Concat,
239            exp.ConcatWs,
240            exp.DateToDateStr,
241            exp.GroupConcat,
242            exp.Initcap,
243            exp.Lower,
244            exp.SafeConcat,
245            exp.SafeDPipe,
246            exp.Substring,
247            exp.TimeToStr,
248            exp.TimeToTimeStr,
249            exp.Trim,
250            exp.TsOrDsToDateStr,
251            exp.UnixToStr,
252            exp.UnixToTimeStr,
253            exp.Upper,
254        },
255    }
256
257    ANNOTATORS: t.Dict = {
258        **{
259            expr_type: lambda self, e: self._annotate_unary(e)
260            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
261        },
262        **{
263            expr_type: lambda self, e: self._annotate_binary(e)
264            for expr_type in subclasses(exp.__name__, exp.Binary)
265        },
266        **{
267            expr_type: _annotate_with_type_lambda(data_type)
268            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
269            for expr_type in expressions
270        },
271        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
272        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
273        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
274        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
275        exp.Bracket: lambda self, e: self._annotate_bracket(e),
276        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
277        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
278        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
279        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
280        exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
281        exp.DateSub: lambda self, e: self._annotate_dateadd(e),
282        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
283        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
284        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
285        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
286        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
287        exp.Literal: lambda self, e: self._annotate_literal(e),
288        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
289        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
290        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
291        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
292        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
293        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
294        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
295        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
296    }
297
298    NESTED_TYPES = {
299        exp.DataType.Type.ARRAY,
300    }
301
302    # Specifies what types a given type can be coerced into (autofilled)
303    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
304
305    # Coercion functions for binary operations.
306    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
307    BINARY_COERCIONS: BinaryCoercions = {
308        **swap_all(
309            {
310                (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
311                for t in exp.DataType.TEXT_TYPES
312            }
313        ),
314        **swap_all(
315            {
316                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
317            }
318        ),
319    }
320
321    def __init__(
322        self,
323        schema: Schema,
324        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
325        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
326        binary_coercions: t.Optional[BinaryCoercions] = None,
327    ) -> None:
328        self.schema = schema
329        self.annotators = annotators or self.ANNOTATORS
330        self.coerces_to = coerces_to or self.COERCES_TO
331        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
332
333        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
334        self._visited: t.Set[int] = set()
335
336    def _set_type(
337        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
338    ) -> None:
339        expression.type = target_type  # type: ignore
340        self._visited.add(id(expression))
341
342    def annotate(self, expression: E) -> E:
343        for scope in traverse_scope(expression):
344            selects = {}
345            for name, source in scope.sources.items():
346                if not isinstance(source, Scope):
347                    continue
348                if isinstance(source.expression, exp.UDTF):
349                    values = []
350
351                    if isinstance(source.expression, exp.Lateral):
352                        if isinstance(source.expression.this, exp.Explode):
353                            values = [source.expression.this.this]
354                    else:
355                        values = source.expression.expressions[0].expressions
356
357                    if not values:
358                        continue
359
360                    selects[name] = {
361                        alias: column
362                        for alias, column in zip(
363                            source.expression.alias_column_names,
364                            values,
365                        )
366                    }
367                else:
368                    selects[name] = {
369                        select.alias_or_name: select for select in source.expression.selects
370                    }
371
372            # First annotate the current scope's column references
373            for col in scope.columns:
374                if not col.table:
375                    continue
376
377                source = scope.sources.get(col.table)
378                if isinstance(source, exp.Table):
379                    self._set_type(col, self.schema.get_column_type(source, col))
380                elif source and col.table in selects and col.name in selects[col.table]:
381                    self._set_type(col, selects[col.table][col.name].type)
382
383            # Then (possibly) annotate the remaining expressions in the scope
384            self._maybe_annotate(scope.expression)
385
386        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
387
388    def _maybe_annotate(self, expression: E) -> E:
389        if id(expression) in self._visited:
390            return expression  # We've already inferred the expression's type
391
392        annotator = self.annotators.get(expression.__class__)
393
394        return (
395            annotator(self, expression)
396            if annotator
397            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
398        )
399
400    def _annotate_args(self, expression: E) -> E:
401        for _, value in expression.iter_expressions():
402            self._maybe_annotate(value)
403
404        return expression
405
406    def _maybe_coerce(
407        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
408    ) -> exp.DataType | exp.DataType.Type:
409        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
410        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
411
412        # We propagate the NULL / UNKNOWN types upwards if found
413        if exp.DataType.Type.NULL in (type1_value, type2_value):
414            return exp.DataType.Type.NULL
415        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
416            return exp.DataType.Type.UNKNOWN
417
418        if type1_value in self.NESTED_TYPES:
419            return type1
420        if type2_value in self.NESTED_TYPES:
421            return type2
422
423        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
424
425    # Note: the following "no_type_check" decorators were added because mypy was yelling due
426    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
427    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
428
429    @t.no_type_check
430    def _annotate_binary(self, expression: B) -> B:
431        self._annotate_args(expression)
432
433        left, right = expression.left, expression.right
434        left_type, right_type = left.type.this, right.type.this
435
436        if isinstance(expression, exp.Connector):
437            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
438                self._set_type(expression, exp.DataType.Type.NULL)
439            elif exp.DataType.Type.NULL in (left_type, right_type):
440                self._set_type(
441                    expression,
442                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
443                )
444            else:
445                self._set_type(expression, exp.DataType.Type.BOOLEAN)
446        elif isinstance(expression, exp.Predicate):
447            self._set_type(expression, exp.DataType.Type.BOOLEAN)
448        elif (left_type, right_type) in self.binary_coercions:
449            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
450        else:
451            self._set_type(expression, self._maybe_coerce(left_type, right_type))
452
453        return expression
454
455    @t.no_type_check
456    def _annotate_unary(self, expression: E) -> E:
457        self._annotate_args(expression)
458
459        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
460            self._set_type(expression, exp.DataType.Type.BOOLEAN)
461        else:
462            self._set_type(expression, expression.this.type)
463
464        return expression
465
466    @t.no_type_check
467    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
468        if expression.is_string:
469            self._set_type(expression, exp.DataType.Type.VARCHAR)
470        elif expression.is_int:
471            self._set_type(expression, exp.DataType.Type.INT)
472        else:
473            self._set_type(expression, exp.DataType.Type.DOUBLE)
474
475        return expression
476
477    @t.no_type_check
478    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
479        self._set_type(expression, target_type)
480        return self._annotate_args(expression)
481
482    @t.no_type_check
483    def _annotate_by_args(
484        self, expression: E, *args: str, promote: bool = False, array: bool = False
485    ) -> E:
486        self._annotate_args(expression)
487
488        expressions: t.List[exp.Expression] = []
489        for arg in args:
490            arg_expr = expression.args.get(arg)
491            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
492
493        last_datatype = None
494        for expr in expressions:
495            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
496
497        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
498
499        if promote:
500            if expression.type.this in exp.DataType.INTEGER_TYPES:
501                self._set_type(expression, exp.DataType.Type.BIGINT)
502            elif expression.type.this in exp.DataType.FLOAT_TYPES:
503                self._set_type(expression, exp.DataType.Type.DOUBLE)
504
505        if array:
506            self._set_type(
507                expression,
508                exp.DataType(
509                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
510                ),
511            )
512
513        return expression
514
515    def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
516        self._annotate_args(expression)
517
518        if expression.this.type.this in exp.DataType.TEXT_TYPES:
519            datatype = _coerce_literal_and_interval(expression.this, expression.interval())
520        elif (
521            expression.this.type.is_type(exp.DataType.Type.DATE)
522            and expression.text("unit").lower() not in DATE_UNITS
523        ):
524            datatype = exp.DataType.Type.DATETIME
525        else:
526            datatype = expression.this.type
527
528        self._set_type(expression, datatype)
529        return expression
530
531    def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
532        self._annotate_args(expression)
533
534        bracket_arg = expression.expressions[0]
535        this = expression.this
536
537        if isinstance(bracket_arg, exp.Slice):
538            self._set_type(expression, this.type)
539        elif this.type.is_type(exp.DataType.Type.ARRAY):
540            contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
541            self._set_type(expression, contained_type)
542        elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
543            index = this.keys.index(bracket_arg)
544            value = seq_get(this.values, index)
545            value_type = value.type if value else exp.DataType.Type.UNKNOWN
546            self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
547        else:
548            self._set_type(expression, exp.DataType.Type.UNKNOWN)
549
550        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None, binary_coercions: Optional[Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]] = None)
321    def __init__(
322        self,
323        schema: Schema,
324        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
325        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
326        binary_coercions: t.Optional[BinaryCoercions] = None,
327    ) -> None:
328        self.schema = schema
329        self.annotators = annotators or self.ANNOTATORS
330        self.coerces_to = coerces_to or self.COERCES_TO
331        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
332
333        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
334        self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] = {<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.Length'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.In'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.StrToDate'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.DatetimeSub'>, <class 'sqlglot.expressions.CurrentDatetime'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Ln'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.Floor'>, <class 
'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.TsOrDiToDi'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.Timestamp'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimestampSub'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Week'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.SafeDPipe'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.DateToDateStr'>}}
ANNOTATORS: Dict = {<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, 
<class 'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: 
<function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Bracket'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 
'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
NESTED_TYPES = {<Type.ARRAY: 'ARRAY'>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] = {<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.NCHAR: 'NCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.NCHAR: 'NCHAR'>, <Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.INT: 'INT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.TINYINT: 'TINYINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>, <Type.SMALLINT: 'SMALLINT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATE: 'DATE'>: {<Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}}
BINARY_COERCIONS: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]] = {(<Type.CHAR: 'CHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.NVARCHAR: 'NVARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.TEXT: 'TEXT'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.NCHAR: 'NCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.VARCHAR: 'VARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.CHAR: 'CHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NVARCHAR: 'NVARCHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.TEXT: 'TEXT'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NCHAR: 'NCHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.VARCHAR: 'VARCHAR'>): <function _coerce_literal_and_interval>, (<Type.DATE: 'DATE'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_date_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.DATE: 'DATE'>): <function _coerce_date_and_interval>}
schema
annotators
coerces_to
binary_coercions
def annotate(self, expression: ~E) -> ~E:
342    def annotate(self, expression: E) -> E:  # Types column refs scope by scope, then annotates the whole tree; returns the annotated expression
343        for scope in traverse_scope(expression):  # NOTE(review): relies on traverse_scope yielding inner scopes before outer ones (see L380) — confirm
344            selects = {}  # source name -> {output column name -> expression producing that column}
345            for name, source in scope.sources.items():
346                if not isinstance(source, Scope):  # only Scope-backed sources are mapped here; exp.Table sources are typed via the schema below
347                    continue
348                if isinstance(source.expression, exp.UDTF):  # table-generating source: pair its alias column names with value expressions
349                    values = []
350
351                    if isinstance(source.expression, exp.Lateral):  # LATERAL: only the EXPLODE form contributes a value
352                        if isinstance(source.expression.this, exp.Explode):
353                            values = [source.expression.this.this]  # the exploded argument expression
354                    else:
355                        values = source.expression.expressions[0].expressions  # other UDTFs: children of the first entry; NOTE(review): IndexError if .expressions is empty
356
357                    if not values:  # nothing to map (e.g. LATERAL without EXPLODE), so columns from this source stay untyped here
358                        continue
359
360                    selects[name] = {
361                        alias: column
362                        for alias, column in zip(  # zip stops at the shorter input: surplus aliases or values are silently dropped
363                            source.expression.alias_column_names,
364                            values,
365                        )
366                    }
367                else:
368                    selects[name] = {  # ordinary query scope: projection output name -> projection expression
369                        select.alias_or_name: select for select in source.expression.selects
370                    }
371
372            # First annotate the current scope's column references
373            for col in scope.columns:
374                if not col.table:  # unqualified columns are not resolved by this loop
375                    continue
376
377                source = scope.sources.get(col.table)
378                if isinstance(source, exp.Table):  # physical table: look the column's type up in the schema
379                    self._set_type(col, self.schema.get_column_type(source, col))
380                elif source and col.table in selects and col.name in selects[col.table]:  # derived source: copy the matching projection's .type (presumably set when that scope was processed)
381                    self._set_type(col, selects[col.table][col.name].type)
382
383            # Then (possibly) annotate the remaining expressions in the scope
384            self._maybe_annotate(scope.expression)
385
386        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions