Edit on GitHub

sqlglot.optimizer.annotate_types

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import exp
  6from sqlglot._typing import E
  7from sqlglot.helper import ensure_list, subclasses
  8from sqlglot.optimizer.scope import Scope, traverse_scope
  9from sqlglot.schema import Schema, ensure_schema
 10
 11if t.TYPE_CHECKING:
 12    B = t.TypeVar("B", bound=exp.Binary)
 13
 14
 15def annotate_types(
 16    expression: E,
 17    schema: t.Optional[t.Dict | Schema] = None,
 18    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
 19    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
 20) -> E:
 21    """
 22    Infers the types of an expression, annotating its AST accordingly.
 23
 24    Example:
 25        >>> import sqlglot
 26        >>> schema = {"y": {"cola": "SMALLINT"}}
 27        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 28        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 29        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 30        <Type.DOUBLE: 'DOUBLE'>
 31
 32    Args:
 33        expression: Expression to annotate.
 34        schema: Database schema.
 35        annotators: Maps expression type to corresponding annotation function.
 36        coerces_to: Maps expression type to set of types that it can be coerced into.
 37
 38    Returns:
 39        The expression annotated with types.
 40    """
 41
 42    schema = ensure_schema(schema)
 43
 44    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 45
 46
 47def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
 48    return lambda self, e: self._annotate_with_type(e, data_type)
 49
 50
 51class _TypeAnnotator(type):
 52    def __new__(cls, clsname, bases, attrs):
 53        klass = super().__new__(cls, clsname, bases, attrs)
 54
 55        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
 56        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
 57        text_precedence = (
 58            exp.DataType.Type.TEXT,
 59            exp.DataType.Type.NVARCHAR,
 60            exp.DataType.Type.VARCHAR,
 61            exp.DataType.Type.NCHAR,
 62            exp.DataType.Type.CHAR,
 63        )
 64        numeric_precedence = (
 65            exp.DataType.Type.DOUBLE,
 66            exp.DataType.Type.FLOAT,
 67            exp.DataType.Type.DECIMAL,
 68            exp.DataType.Type.BIGINT,
 69            exp.DataType.Type.INT,
 70            exp.DataType.Type.SMALLINT,
 71            exp.DataType.Type.TINYINT,
 72        )
 73        timelike_precedence = (
 74            exp.DataType.Type.TIMESTAMPLTZ,
 75            exp.DataType.Type.TIMESTAMPTZ,
 76            exp.DataType.Type.TIMESTAMP,
 77            exp.DataType.Type.DATETIME,
 78            exp.DataType.Type.DATE,
 79        )
 80
 81        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
 82            coerces_to = set()
 83            for data_type in type_precedence:
 84                klass.COERCES_TO[data_type] = coerces_to.copy()
 85                coerces_to |= {data_type}
 86
 87        return klass
 88
 89
 90class TypeAnnotator(metaclass=_TypeAnnotator):
 91    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
 92        exp.DataType.Type.BIGINT: {
 93            exp.ApproxDistinct,
 94            exp.ArraySize,
 95            exp.Count,
 96            exp.Length,
 97        },
 98        exp.DataType.Type.BOOLEAN: {
 99            exp.Between,
100            exp.Boolean,
101            exp.In,
102            exp.RegexpLike,
103        },
104        exp.DataType.Type.DATE: {
105            exp.CurrentDate,
106            exp.Date,
107            exp.DateAdd,
108            exp.DateFromParts,
109            exp.DateStrToDate,
110            exp.DateSub,
111            exp.DateTrunc,
112            exp.DiToDate,
113            exp.StrToDate,
114            exp.TimeStrToDate,
115            exp.TsOrDsToDate,
116        },
117        exp.DataType.Type.DATETIME: {
118            exp.CurrentDatetime,
119            exp.DatetimeAdd,
120            exp.DatetimeSub,
121        },
122        exp.DataType.Type.DOUBLE: {
123            exp.ApproxQuantile,
124            exp.Avg,
125            exp.Exp,
126            exp.Ln,
127            exp.Log,
128            exp.Log2,
129            exp.Log10,
130            exp.Pow,
131            exp.Quantile,
132            exp.Round,
133            exp.SafeDivide,
134            exp.Sqrt,
135            exp.Stddev,
136            exp.StddevPop,
137            exp.StddevSamp,
138            exp.Variance,
139            exp.VariancePop,
140        },
141        exp.DataType.Type.INT: {
142            exp.Ceil,
143            exp.DateDiff,
144            exp.DatetimeDiff,
145            exp.Extract,
146            exp.TimestampDiff,
147            exp.TimeDiff,
148            exp.DateToDi,
149            exp.Floor,
150            exp.Levenshtein,
151            exp.StrPosition,
152            exp.TsOrDiToDi,
153        },
154        exp.DataType.Type.TIMESTAMP: {
155            exp.CurrentTime,
156            exp.CurrentTimestamp,
157            exp.StrToTime,
158            exp.TimeAdd,
159            exp.TimeStrToTime,
160            exp.TimeSub,
161            exp.Timestamp,
162            exp.TimestampAdd,
163            exp.TimestampSub,
164            exp.UnixToTime,
165        },
166        exp.DataType.Type.TINYINT: {
167            exp.Day,
168            exp.Month,
169            exp.Week,
170            exp.Year,
171        },
172        exp.DataType.Type.VARCHAR: {
173            exp.ArrayConcat,
174            exp.Concat,
175            exp.ConcatWs,
176            exp.DateToDateStr,
177            exp.GroupConcat,
178            exp.Initcap,
179            exp.Lower,
180            exp.SafeConcat,
181            exp.SafeDPipe,
182            exp.Substring,
183            exp.TimeToStr,
184            exp.TimeToTimeStr,
185            exp.Trim,
186            exp.TsOrDsToDateStr,
187            exp.UnixToStr,
188            exp.UnixToTimeStr,
189            exp.Upper,
190        },
191    }
192
193    ANNOTATORS: t.Dict = {
194        **{
195            expr_type: lambda self, e: self._annotate_unary(e)
196            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
197        },
198        **{
199            expr_type: lambda self, e: self._annotate_binary(e)
200            for expr_type in subclasses(exp.__name__, exp.Binary)
201        },
202        **{
203            expr_type: _annotate_with_type_lambda(data_type)
204            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
205            for expr_type in expressions
206        },
207        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
208        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
209        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
210        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
211        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
212        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
213        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
214        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
215        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
216        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
217        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
218        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
219        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
220        exp.Literal: lambda self, e: self._annotate_literal(e),
221        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
222        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
223        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
224        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
225        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
226        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
227        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
228    }
229
230    NESTED_TYPES = {
231        exp.DataType.Type.ARRAY,
232    }
233
234    # Specifies what types a given type can be coerced into (autofilled)
235    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
236
237    def __init__(
238        self,
239        schema: Schema,
240        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
241        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
242    ) -> None:
243        self.schema = schema
244        self.annotators = annotators or self.ANNOTATORS
245        self.coerces_to = coerces_to or self.COERCES_TO
246
247        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
248        self._visited: t.Set[int] = set()
249
250    def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
251        expression.type = target_type
252        self._visited.add(id(expression))
253
254    def annotate(self, expression: E) -> E:
255        for scope in traverse_scope(expression):
256            selects = {}
257            for name, source in scope.sources.items():
258                if not isinstance(source, Scope):
259                    continue
260                if isinstance(source.expression, exp.UDTF):
261                    values = []
262
263                    if isinstance(source.expression, exp.Lateral):
264                        if isinstance(source.expression.this, exp.Explode):
265                            values = [source.expression.this.this]
266                    else:
267                        values = source.expression.expressions[0].expressions
268
269                    if not values:
270                        continue
271
272                    selects[name] = {
273                        alias: column
274                        for alias, column in zip(
275                            source.expression.alias_column_names,
276                            values,
277                        )
278                    }
279                else:
280                    selects[name] = {
281                        select.alias_or_name: select for select in source.expression.selects
282                    }
283
284            # First annotate the current scope's column references
285            for col in scope.columns:
286                if not col.table:
287                    continue
288
289                source = scope.sources.get(col.table)
290                if isinstance(source, exp.Table):
291                    self._set_type(col, self.schema.get_column_type(source, col))
292                elif source and col.table in selects and col.name in selects[col.table]:
293                    self._set_type(col, selects[col.table][col.name].type)
294
295            # Then (possibly) annotate the remaining expressions in the scope
296            self._maybe_annotate(scope.expression)
297
298        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
299
300    def _maybe_annotate(self, expression: E) -> E:
301        if id(expression) in self._visited:
302            return expression  # We've already inferred the expression's type
303
304        annotator = self.annotators.get(expression.__class__)
305
306        return (
307            annotator(self, expression)
308            if annotator
309            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
310        )
311
312    def _annotate_args(self, expression: E) -> E:
313        for _, value in expression.iter_expressions():
314            self._maybe_annotate(value)
315
316        return expression
317
318    def _maybe_coerce(
319        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
320    ) -> exp.DataType | exp.DataType.Type:
321        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
322        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
323
324        # We propagate the NULL / UNKNOWN types upwards if found
325        if exp.DataType.Type.NULL in (type1_value, type2_value):
326            return exp.DataType.Type.NULL
327        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
328            return exp.DataType.Type.UNKNOWN
329
330        if type1_value in self.NESTED_TYPES:
331            return type1
332        if type2_value in self.NESTED_TYPES:
333            return type2
334
335        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
336
337    # Note: the following "no_type_check" decorators were added because mypy was yelling due
338    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
339    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
340
341    @t.no_type_check
342    def _annotate_binary(self, expression: B) -> B:
343        self._annotate_args(expression)
344
345        left_type = expression.left.type.this
346        right_type = expression.right.type.this
347
348        if isinstance(expression, exp.Connector):
349            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
350                self._set_type(expression, exp.DataType.Type.NULL)
351            elif exp.DataType.Type.NULL in (left_type, right_type):
352                self._set_type(
353                    expression,
354                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
355                )
356            else:
357                self._set_type(expression, exp.DataType.Type.BOOLEAN)
358        elif isinstance(expression, exp.Predicate):
359            self._set_type(expression, exp.DataType.Type.BOOLEAN)
360        else:
361            self._set_type(expression, self._maybe_coerce(left_type, right_type))
362
363        return expression
364
365    @t.no_type_check
366    def _annotate_unary(self, expression: E) -> E:
367        self._annotate_args(expression)
368
369        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
370            self._set_type(expression, exp.DataType.Type.BOOLEAN)
371        else:
372            self._set_type(expression, expression.this.type)
373
374        return expression
375
376    @t.no_type_check
377    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
378        if expression.is_string:
379            self._set_type(expression, exp.DataType.Type.VARCHAR)
380        elif expression.is_int:
381            self._set_type(expression, exp.DataType.Type.INT)
382        else:
383            self._set_type(expression, exp.DataType.Type.DOUBLE)
384
385        return expression
386
387    @t.no_type_check
388    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
389        self._set_type(expression, target_type)
390        return self._annotate_args(expression)
391
392    @t.no_type_check
393    def _annotate_by_args(
394        self, expression: E, *args: str, promote: bool = False, array: bool = False
395    ) -> E:
396        self._annotate_args(expression)
397
398        expressions: t.List[exp.Expression] = []
399        for arg in args:
400            arg_expr = expression.args.get(arg)
401            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
402
403        last_datatype = None
404        for expr in expressions:
405            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
406
407        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
408
409        if promote:
410            if expression.type.this in exp.DataType.INTEGER_TYPES:
411                self._set_type(expression, exp.DataType.Type.BIGINT)
412            elif expression.type.this in exp.DataType.FLOAT_TYPES:
413                self._set_type(expression, exp.DataType.Type.DOUBLE)
414
415        if array:
416            self._set_type(
417                expression,
418                exp.DataType(
419                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
420                ),
421            )
422
423        return expression
def annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema, either a `Schema` or a value accepted by `ensure_schema`.
        annotators: Maps each expression class to the function that annotates it. When this is
            not provided (or is empty), `TypeAnnotator.ANNOTATORS` is used.
        coerces_to: Maps each data type to the set of data types it can be coerced into. When
            this is not provided (or is empty), `TypeAnnotator.COERCES_TO` is used.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

Infers the types of an expression, annotating its AST accordingly.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression: Expression to annotate.
  • schema: Database schema.
  • annotators: Maps each expression class to the function that annotates it.
  • coerces_to: Maps each data type to the set of data types it can be coerced into.
Returns:

The expression annotated with types.

class TypeAnnotator:
 91class TypeAnnotator(metaclass=_TypeAnnotator):
 92    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
 93        exp.DataType.Type.BIGINT: {
 94            exp.ApproxDistinct,
 95            exp.ArraySize,
 96            exp.Count,
 97            exp.Length,
 98        },
 99        exp.DataType.Type.BOOLEAN: {
100            exp.Between,
101            exp.Boolean,
102            exp.In,
103            exp.RegexpLike,
104        },
105        exp.DataType.Type.DATE: {
106            exp.CurrentDate,
107            exp.Date,
108            exp.DateAdd,
109            exp.DateFromParts,
110            exp.DateStrToDate,
111            exp.DateSub,
112            exp.DateTrunc,
113            exp.DiToDate,
114            exp.StrToDate,
115            exp.TimeStrToDate,
116            exp.TsOrDsToDate,
117        },
118        exp.DataType.Type.DATETIME: {
119            exp.CurrentDatetime,
120            exp.DatetimeAdd,
121            exp.DatetimeSub,
122        },
123        exp.DataType.Type.DOUBLE: {
124            exp.ApproxQuantile,
125            exp.Avg,
126            exp.Exp,
127            exp.Ln,
128            exp.Log,
129            exp.Log2,
130            exp.Log10,
131            exp.Pow,
132            exp.Quantile,
133            exp.Round,
134            exp.SafeDivide,
135            exp.Sqrt,
136            exp.Stddev,
137            exp.StddevPop,
138            exp.StddevSamp,
139            exp.Variance,
140            exp.VariancePop,
141        },
142        exp.DataType.Type.INT: {
143            exp.Ceil,
144            exp.DateDiff,
145            exp.DatetimeDiff,
146            exp.Extract,
147            exp.TimestampDiff,
148            exp.TimeDiff,
149            exp.DateToDi,
150            exp.Floor,
151            exp.Levenshtein,
152            exp.StrPosition,
153            exp.TsOrDiToDi,
154        },
155        exp.DataType.Type.TIMESTAMP: {
156            exp.CurrentTime,
157            exp.CurrentTimestamp,
158            exp.StrToTime,
159            exp.TimeAdd,
160            exp.TimeStrToTime,
161            exp.TimeSub,
162            exp.Timestamp,
163            exp.TimestampAdd,
164            exp.TimestampSub,
165            exp.UnixToTime,
166        },
167        exp.DataType.Type.TINYINT: {
168            exp.Day,
169            exp.Month,
170            exp.Week,
171            exp.Year,
172        },
173        exp.DataType.Type.VARCHAR: {
174            exp.ArrayConcat,
175            exp.Concat,
176            exp.ConcatWs,
177            exp.DateToDateStr,
178            exp.GroupConcat,
179            exp.Initcap,
180            exp.Lower,
181            exp.SafeConcat,
182            exp.SafeDPipe,
183            exp.Substring,
184            exp.TimeToStr,
185            exp.TimeToTimeStr,
186            exp.Trim,
187            exp.TsOrDsToDateStr,
188            exp.UnixToStr,
189            exp.UnixToTimeStr,
190            exp.Upper,
191        },
192    }
193
194    ANNOTATORS: t.Dict = {
195        **{
196            expr_type: lambda self, e: self._annotate_unary(e)
197            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
198        },
199        **{
200            expr_type: lambda self, e: self._annotate_binary(e)
201            for expr_type in subclasses(exp.__name__, exp.Binary)
202        },
203        **{
204            expr_type: _annotate_with_type_lambda(data_type)
205            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
206            for expr_type in expressions
207        },
208        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
209        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
210        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
211        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
212        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
213        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
214        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
215        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
216        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
217        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
218        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
219        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
220        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
221        exp.Literal: lambda self, e: self._annotate_literal(e),
222        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
223        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
224        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
225        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
226        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
227        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
228        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
229    }
230
231    NESTED_TYPES = {
232        exp.DataType.Type.ARRAY,
233    }
234
235    # Specifies what types a given type can be coerced into (autofilled)
236    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
237
238    def __init__(
239        self,
240        schema: Schema,
241        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
242        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
243    ) -> None:
244        self.schema = schema
245        self.annotators = annotators or self.ANNOTATORS
246        self.coerces_to = coerces_to or self.COERCES_TO
247
248        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
249        self._visited: t.Set[int] = set()
250
251    def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
252        expression.type = target_type
253        self._visited.add(id(expression))
254
255    def annotate(self, expression: E) -> E:
256        for scope in traverse_scope(expression):
257            selects = {}
258            for name, source in scope.sources.items():
259                if not isinstance(source, Scope):
260                    continue
261                if isinstance(source.expression, exp.UDTF):
262                    values = []
263
264                    if isinstance(source.expression, exp.Lateral):
265                        if isinstance(source.expression.this, exp.Explode):
266                            values = [source.expression.this.this]
267                    else:
268                        values = source.expression.expressions[0].expressions
269
270                    if not values:
271                        continue
272
273                    selects[name] = {
274                        alias: column
275                        for alias, column in zip(
276                            source.expression.alias_column_names,
277                            values,
278                        )
279                    }
280                else:
281                    selects[name] = {
282                        select.alias_or_name: select for select in source.expression.selects
283                    }
284
285            # First annotate the current scope's column references
286            for col in scope.columns:
287                if not col.table:
288                    continue
289
290                source = scope.sources.get(col.table)
291                if isinstance(source, exp.Table):
292                    self._set_type(col, self.schema.get_column_type(source, col))
293                elif source and col.table in selects and col.name in selects[col.table]:
294                    self._set_type(col, selects[col.table][col.name].type)
295
296            # Then (possibly) annotate the remaining expressions in the scope
297            self._maybe_annotate(scope.expression)
298
299        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
300
301    def _maybe_annotate(self, expression: E) -> E:
302        if id(expression) in self._visited:
303            return expression  # We've already inferred the expression's type
304
305        annotator = self.annotators.get(expression.__class__)
306
307        return (
308            annotator(self, expression)
309            if annotator
310            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
311        )
312
313    def _annotate_args(self, expression: E) -> E:
314        for _, value in expression.iter_expressions():
315            self._maybe_annotate(value)
316
317        return expression
318
319    def _maybe_coerce(
320        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
321    ) -> exp.DataType | exp.DataType.Type:
322        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
323        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
324
325        # We propagate the NULL / UNKNOWN types upwards if found
326        if exp.DataType.Type.NULL in (type1_value, type2_value):
327            return exp.DataType.Type.NULL
328        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
329            return exp.DataType.Type.UNKNOWN
330
331        if type1_value in self.NESTED_TYPES:
332            return type1
333        if type2_value in self.NESTED_TYPES:
334            return type2
335
336        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
337
338    # Note: the following "no_type_check" decorators were added because mypy was yelling due
339    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
340    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
341
342    @t.no_type_check
343    def _annotate_binary(self, expression: B) -> B:
344        self._annotate_args(expression)
345
346        left_type = expression.left.type.this
347        right_type = expression.right.type.this
348
349        if isinstance(expression, exp.Connector):
350            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
351                self._set_type(expression, exp.DataType.Type.NULL)
352            elif exp.DataType.Type.NULL in (left_type, right_type):
353                self._set_type(
354                    expression,
355                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
356                )
357            else:
358                self._set_type(expression, exp.DataType.Type.BOOLEAN)
359        elif isinstance(expression, exp.Predicate):
360            self._set_type(expression, exp.DataType.Type.BOOLEAN)
361        else:
362            self._set_type(expression, self._maybe_coerce(left_type, right_type))
363
364        return expression
365
366    @t.no_type_check
367    def _annotate_unary(self, expression: E) -> E:
368        self._annotate_args(expression)
369
370        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
371            self._set_type(expression, exp.DataType.Type.BOOLEAN)
372        else:
373            self._set_type(expression, expression.this.type)
374
375        return expression
376
377    @t.no_type_check
378    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
379        if expression.is_string:
380            self._set_type(expression, exp.DataType.Type.VARCHAR)
381        elif expression.is_int:
382            self._set_type(expression, exp.DataType.Type.INT)
383        else:
384            self._set_type(expression, exp.DataType.Type.DOUBLE)
385
386        return expression
387
388    @t.no_type_check
389    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
390        self._set_type(expression, target_type)
391        return self._annotate_args(expression)
392
393    @t.no_type_check
394    def _annotate_by_args(
395        self, expression: E, *args: str, promote: bool = False, array: bool = False
396    ) -> E:
397        self._annotate_args(expression)
398
399        expressions: t.List[exp.Expression] = []
400        for arg in args:
401            arg_expr = expression.args.get(arg)
402            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
403
404        last_datatype = None
405        for expr in expressions:
406            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
407
408        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
409
410        if promote:
411            if expression.type.this in exp.DataType.INTEGER_TYPES:
412                self._set_type(expression, exp.DataType.Type.BIGINT)
413            elif expression.type.this in exp.DataType.FLOAT_TYPES:
414                self._set_type(expression, exp.DataType.Type.DOUBLE)
415
416        if array:
417            self._set_type(
418                expression,
419                exp.DataType(
420                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
421                ),
422            )
423
424        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None)
238    def __init__(
239        self,
240        schema: Schema,
241        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
242        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
243    ) -> None:
244        self.schema = schema
245        self.annotators = annotators or self.ANNOTATORS
246        self.coerces_to = coerces_to or self.COERCES_TO
247
248        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
249        self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] = {<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ApproxDistinct'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.Boolean'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.DateSub'>, <class 'sqlglot.expressions.DateAdd'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Quantile'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.DateToDi'>, <class 
'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.DatetimeDiff'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.Timestamp'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.TimeSub'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Day'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.SafeDPipe'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.ConcatWs'>}}
ANNOTATORS: Dict = {<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, 
<class 'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: 
<function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function 
TypeAnnotator.<lambda>>}
NESTED_TYPES = {<Type.ARRAY: 'ARRAY'>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] = {<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.NCHAR: 'NCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>, <Type.NCHAR: 'NCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.BIGINT: 'BIGINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.INT: 'INT'>: {<Type.FLOAT: 'FLOAT'>, <Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.INT: 'INT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.BIGINT: 'BIGINT'>, <Type.SMALLINT: 'SMALLINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.INT: 'INT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATE: 'DATE'>: {<Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}}
schema
annotators
coerces_to
def annotate(self, expression: ~E) -> ~E:
255    def annotate(self, expression: E) -> E:
256        for scope in traverse_scope(expression):
257            selects = {}
258            for name, source in scope.sources.items():
259                if not isinstance(source, Scope):
260                    continue
261                if isinstance(source.expression, exp.UDTF):
262                    values = []
263
264                    if isinstance(source.expression, exp.Lateral):
265                        if isinstance(source.expression.this, exp.Explode):
266                            values = [source.expression.this.this]
267                    else:
268                        values = source.expression.expressions[0].expressions
269
270                    if not values:
271                        continue
272
273                    selects[name] = {
274                        alias: column
275                        for alias, column in zip(
276                            source.expression.alias_column_names,
277                            values,
278                        )
279                    }
280                else:
281                    selects[name] = {
282                        select.alias_or_name: select for select in source.expression.selects
283                    }
284
285            # First annotate the current scope's column references
286            for col in scope.columns:
287                if not col.table:
288                    continue
289
290                source = scope.sources.get(col.table)
291                if isinstance(source, exp.Table):
292                    self._set_type(col, self.schema.get_column_type(source, col))
293                elif source and col.table in selects and col.name in selects[col.table]:
294                    self._set_type(col, selects[col.table][col.name].type)
295
296            # Then (possibly) annotate the remaining expressions in the scope
297            self._maybe_annotate(scope.expression)
298
299        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions