Edit on GitHub

sqlglot.optimizer.annotate_types

  1from sqlglot import exp
  2from sqlglot.helper import ensure_list, subclasses
  3from sqlglot.optimizer.scope import Scope, traverse_scope
  4from sqlglot.schema import ensure_schema
  5
  6
  7def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
  8    """
  9    Recursively infer & annotate types in an expression syntax tree against a schema.
 10    Assumes that we've already executed the optimizer's qualify_columns step.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> schema = {"y": {"cola": "SMALLINT"}}
 15        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 16        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 17        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 18        <Type.DOUBLE: 'DOUBLE'>
 19
 20    Args:
 21        expression (sqlglot.Expression): Expression to annotate.
 22        schema (dict|sqlglot.optimizer.Schema): Database schema.
 23        annotators (dict): Maps expression type to corresponding annotation function.
 24        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
 25    Returns:
 26        sqlglot.Expression: expression annotated with types
 27    """
 28
 29    schema = ensure_schema(schema)
 30
 31    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 32
 33
 34class TypeAnnotator:
 35    ANNOTATORS = {
 36        **{
 37            expr_type: lambda self, expr: self._annotate_unary(expr)
 38            for expr_type in subclasses(exp.__name__, exp.Unary)
 39        },
 40        **{
 41            expr_type: lambda self, expr: self._annotate_binary(expr)
 42            for expr_type in subclasses(exp.__name__, exp.Binary)
 43        },
 44        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 45        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 47        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 48        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 49        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 51        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 52        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 53        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 54        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 55            expr, exp.DataType.Type.BIGINT
 56        ),
 57        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 58        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 59        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 60        exp.Sum: lambda self, expr: self._annotate_by_args(
 61            expr, "this", "expressions", promote=True
 62        ),
 63        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 64        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 65        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 66        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 67            expr, exp.DataType.Type.DATETIME
 68        ),
 69        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 70            expr, exp.DataType.Type.TIMESTAMP
 71        ),
 72        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 73            expr, exp.DataType.Type.TIMESTAMP
 74        ),
 75        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 76        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 77        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 78        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 79            expr, exp.DataType.Type.DATETIME
 80        ),
 81        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 82            expr, exp.DataType.Type.DATETIME
 83        ),
 84        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 85        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 86        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 87            expr, exp.DataType.Type.TIMESTAMP
 88        ),
 89        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 90            expr, exp.DataType.Type.TIMESTAMP
 91        ),
 92        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 93        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 94        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 95        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 96        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 97            expr, exp.DataType.Type.DATE
 98        ),
 99        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
100            expr, exp.DataType.Type.VARCHAR
101        ),
102        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
103        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
104        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
105        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
106        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
107        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
108        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
109        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
110        exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
111        exp.SafeConcat: lambda self, expr: self._annotate_with_type(
112            expr, exp.DataType.Type.VARCHAR
113        ),
114        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
115        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
116            expr, exp.DataType.Type.VARCHAR
117        ),
118        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
119            expr, exp.DataType.Type.VARCHAR
120        ),
121        exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
122        exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
123        exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
124        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
125        exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
126        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
127        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
128        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
129        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
130        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
131        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
132        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
133        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
134        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
135        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
136        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
137        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
138            expr, exp.DataType.Type.DOUBLE
139        ),
140        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
141            expr, exp.DataType.Type.BOOLEAN
142        ),
143        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
144        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
145        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
146        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
147        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
148        exp.StrToTime: lambda self, expr: self._annotate_with_type(
149            expr, exp.DataType.Type.TIMESTAMP
150        ),
151        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
152        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
153        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
154        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
155        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
156        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
157            expr, exp.DataType.Type.VARCHAR
158        ),
159        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
160            expr, exp.DataType.Type.DATE
161        ),
162        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
163            expr, exp.DataType.Type.TIMESTAMP
164        ),
165        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
166        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
167            expr, exp.DataType.Type.VARCHAR
168        ),
169        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
170        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
171        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
172        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
173            expr, exp.DataType.Type.TIMESTAMP
174        ),
175        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
176            expr, exp.DataType.Type.VARCHAR
177        ),
178        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
179        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
180        exp.VariancePop: lambda self, expr: self._annotate_with_type(
181            expr, exp.DataType.Type.DOUBLE
182        ),
183        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
184        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
185    }
186
187    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
188    COERCES_TO = {
189        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
190        exp.DataType.Type.TEXT: set(),
191        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
192        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
193        exp.DataType.Type.NCHAR: {
194            exp.DataType.Type.VARCHAR,
195            exp.DataType.Type.NVARCHAR,
196            exp.DataType.Type.TEXT,
197        },
198        exp.DataType.Type.CHAR: {
199            exp.DataType.Type.NCHAR,
200            exp.DataType.Type.VARCHAR,
201            exp.DataType.Type.NVARCHAR,
202            exp.DataType.Type.TEXT,
203        },
204        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
205        exp.DataType.Type.DOUBLE: set(),
206        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
207        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
208        exp.DataType.Type.BIGINT: {
209            exp.DataType.Type.DECIMAL,
210            exp.DataType.Type.FLOAT,
211            exp.DataType.Type.DOUBLE,
212        },
213        exp.DataType.Type.INT: {
214            exp.DataType.Type.BIGINT,
215            exp.DataType.Type.DECIMAL,
216            exp.DataType.Type.FLOAT,
217            exp.DataType.Type.DOUBLE,
218        },
219        exp.DataType.Type.SMALLINT: {
220            exp.DataType.Type.INT,
221            exp.DataType.Type.BIGINT,
222            exp.DataType.Type.DECIMAL,
223            exp.DataType.Type.FLOAT,
224            exp.DataType.Type.DOUBLE,
225        },
226        exp.DataType.Type.TINYINT: {
227            exp.DataType.Type.SMALLINT,
228            exp.DataType.Type.INT,
229            exp.DataType.Type.BIGINT,
230            exp.DataType.Type.DECIMAL,
231            exp.DataType.Type.FLOAT,
232            exp.DataType.Type.DOUBLE,
233        },
234        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
235        exp.DataType.Type.TIMESTAMPLTZ: set(),
236        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
237        exp.DataType.Type.TIMESTAMP: {
238            exp.DataType.Type.TIMESTAMPTZ,
239            exp.DataType.Type.TIMESTAMPLTZ,
240        },
241        exp.DataType.Type.DATETIME: {
242            exp.DataType.Type.TIMESTAMP,
243            exp.DataType.Type.TIMESTAMPTZ,
244            exp.DataType.Type.TIMESTAMPLTZ,
245        },
246        exp.DataType.Type.DATE: {
247            exp.DataType.Type.DATETIME,
248            exp.DataType.Type.TIMESTAMP,
249            exp.DataType.Type.TIMESTAMPTZ,
250            exp.DataType.Type.TIMESTAMPLTZ,
251        },
252    }
253
254    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
255
256    def __init__(self, schema=None, annotators=None, coerces_to=None):
257        self.schema = schema
258        self.annotators = annotators or self.ANNOTATORS
259        self.coerces_to = coerces_to or self.COERCES_TO
260
261    def annotate(self, expression):
262        if isinstance(expression, self.TRAVERSABLES):
263            for scope in traverse_scope(expression):
264                selects = {}
265                for name, source in scope.sources.items():
266                    if not isinstance(source, Scope):
267                        continue
268                    if isinstance(source.expression, exp.UDTF):
269                        values = []
270
271                        if isinstance(source.expression, exp.Lateral):
272                            if isinstance(source.expression.this, exp.Explode):
273                                values = [source.expression.this.this]
274                        else:
275                            values = source.expression.expressions[0].expressions
276
277                        if not values:
278                            continue
279
280                        selects[name] = {
281                            alias: column
282                            for alias, column in zip(
283                                source.expression.alias_column_names,
284                                values,
285                            )
286                        }
287                    else:
288                        selects[name] = {
289                            select.alias_or_name: select for select in source.expression.selects
290                        }
291                # First annotate the current scope's column references
292                for col in scope.columns:
293                    if not col.table:
294                        continue
295
296                    source = scope.sources.get(col.table)
297                    if isinstance(source, exp.Table):
298                        col.type = self.schema.get_column_type(source, col)
299                    elif source and col.table in selects and col.name in selects[col.table]:
300                        col.type = selects[col.table][col.name].type
301                # Then (possibly) annotate the remaining expressions in the scope
302                self._maybe_annotate(scope.expression)
303        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
304
305    def _maybe_annotate(self, expression):
306        if expression.type:
307            return expression  # We've already inferred the expression's type
308
309        annotator = self.annotators.get(expression.__class__)
310
311        return (
312            annotator(self, expression)
313            if annotator
314            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
315        )
316
317    def _annotate_args(self, expression):
318        for _, value in expression.iter_expressions():
319            self._maybe_annotate(value)
320
321        return expression
322
323    def _maybe_coerce(self, type1, type2):
324        # We propagate the NULL / UNKNOWN types upwards if found
325        if isinstance(type1, exp.DataType):
326            type1 = type1.this
327        if isinstance(type2, exp.DataType):
328            type2 = type2.this
329
330        if exp.DataType.Type.NULL in (type1, type2):
331            return exp.DataType.Type.NULL
332        if exp.DataType.Type.UNKNOWN in (type1, type2):
333            return exp.DataType.Type.UNKNOWN
334
335        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
336
337    def _annotate_binary(self, expression):
338        self._annotate_args(expression)
339
340        left_type = expression.left.type.this
341        right_type = expression.right.type.this
342
343        if isinstance(expression, exp.Connector):
344            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
345                expression.type = exp.DataType.Type.NULL
346            elif exp.DataType.Type.NULL in (left_type, right_type):
347                expression.type = exp.DataType.build(
348                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
349                )
350            else:
351                expression.type = exp.DataType.Type.BOOLEAN
352        elif isinstance(expression, exp.Predicate):
353            expression.type = exp.DataType.Type.BOOLEAN
354        else:
355            expression.type = self._maybe_coerce(left_type, right_type)
356
357        return expression
358
359    def _annotate_unary(self, expression):
360        self._annotate_args(expression)
361
362        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
363            expression.type = exp.DataType.Type.BOOLEAN
364        else:
365            expression.type = expression.this.type
366
367        return expression
368
369    def _annotate_literal(self, expression):
370        if expression.is_string:
371            expression.type = exp.DataType.Type.VARCHAR
372        elif expression.is_int:
373            expression.type = exp.DataType.Type.INT
374        else:
375            expression.type = exp.DataType.Type.DOUBLE
376
377        return expression
378
379    def _annotate_with_type(self, expression, target_type):
380        expression.type = target_type
381        return self._annotate_args(expression)
382
383    def _annotate_by_args(self, expression, *args, promote=False):
384        self._annotate_args(expression)
385        expressions = []
386        for arg in args:
387            arg_expr = expression.args.get(arg)
388            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
389
390        last_datatype = None
391        for expr in expressions:
392            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
393
394        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
395
396        if promote:
397            if expression.type.this in exp.DataType.INTEGER_TYPES:
398                expression.type = exp.DataType.Type.BIGINT
399            elif expression.type.this in exp.DataType.FLOAT_TYPES:
400                expression.type = exp.DataType.Type.DOUBLE
401
402        return expression
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 8def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 9    """
10    Recursively infer & annotate types in an expression syntax tree against a schema.
11    Assumes that we've already executed the optimizer's qualify_columns step.
12
13    Example:
14        >>> import sqlglot
15        >>> schema = {"y": {"cola": "SMALLINT"}}
16        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
17        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
18        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
19        <Type.DOUBLE: 'DOUBLE'>
20
21    Args:
22        expression (sqlglot.Expression): Expression to annotate.
23        schema (dict|sqlglot.optimizer.Schema): Database schema.
24        annotators (dict): Maps expression type to corresponding annotation function.
25        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
26    Returns:
27        sqlglot.Expression: expression annotated with types
28    """
29
30    schema = ensure_schema(schema)
31
32    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression (sqlglot.Expression): Expression to annotate.
  • schema (dict|sqlglot.schema.Schema): Database schema.
  • annotators (dict): Maps expression type to corresponding annotation function.
  • coerces_to (dict): Maps a data type to the set of data types that it can be coerced into.
Returns:

sqlglot.Expression: expression annotated with types

class TypeAnnotator:
 35class TypeAnnotator:
 36    ANNOTATORS = {
 37        **{
 38            expr_type: lambda self, expr: self._annotate_unary(expr)
 39            for expr_type in subclasses(exp.__name__, exp.Unary)
 40        },
 41        **{
 42            expr_type: lambda self, expr: self._annotate_binary(expr)
 43            for expr_type in subclasses(exp.__name__, exp.Binary)
 44        },
 45        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 47        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 48        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 49        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 51        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 52        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 53        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 54        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 55        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 56            expr, exp.DataType.Type.BIGINT
 57        ),
 58        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 59        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 60        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 61        exp.Sum: lambda self, expr: self._annotate_by_args(
 62            expr, "this", "expressions", promote=True
 63        ),
 64        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 65        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 66        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 67        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 68            expr, exp.DataType.Type.DATETIME
 69        ),
 70        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 71            expr, exp.DataType.Type.TIMESTAMP
 72        ),
 73        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 74            expr, exp.DataType.Type.TIMESTAMP
 75        ),
 76        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 77        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 78        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 79        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 80            expr, exp.DataType.Type.DATETIME
 81        ),
 82        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 83            expr, exp.DataType.Type.DATETIME
 84        ),
 85        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 86        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 87        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 88            expr, exp.DataType.Type.TIMESTAMP
 89        ),
 90        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 91            expr, exp.DataType.Type.TIMESTAMP
 92        ),
 93        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 94        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 95        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 96        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 97        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 98            expr, exp.DataType.Type.DATE
 99        ),
100        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
101            expr, exp.DataType.Type.VARCHAR
102        ),
103        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
104        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
105        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
106        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
107        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
108        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
109        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
110        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
111        exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
112        exp.SafeConcat: lambda self, expr: self._annotate_with_type(
113            expr, exp.DataType.Type.VARCHAR
114        ),
115        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
116        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
117            expr, exp.DataType.Type.VARCHAR
118        ),
119        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
120            expr, exp.DataType.Type.VARCHAR
121        ),
122        exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
123        exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
124        exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
125        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
126        exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
127        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
128        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
129        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
130        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
131        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
132        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
133        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
134        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
135        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
136        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
137        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
138        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
139            expr, exp.DataType.Type.DOUBLE
140        ),
141        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
142            expr, exp.DataType.Type.BOOLEAN
143        ),
144        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
145        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
146        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
147        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
148        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
149        exp.StrToTime: lambda self, expr: self._annotate_with_type(
150            expr, exp.DataType.Type.TIMESTAMP
151        ),
152        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
153        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
154        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
155        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
156        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
157        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
158            expr, exp.DataType.Type.VARCHAR
159        ),
160        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
161            expr, exp.DataType.Type.DATE
162        ),
163        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
164            expr, exp.DataType.Type.TIMESTAMP
165        ),
166        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
167        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
168            expr, exp.DataType.Type.VARCHAR
169        ),
170        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
171        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
172        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
173        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
174            expr, exp.DataType.Type.TIMESTAMP
175        ),
176        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
177            expr, exp.DataType.Type.VARCHAR
178        ),
179        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
180        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
181        exp.VariancePop: lambda self, expr: self._annotate_with_type(
182            expr, exp.DataType.Type.DOUBLE
183        ),
184        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
185        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
186    }
187
    # Implicit type-widening table: maps a type to the set of types it can be coerced into.
    # `_maybe_coerce(type1, type2)` looks up `type1` here and yields `type2` only when `type2`
    # is a member of that set; otherwise it keeps `type1`. Within each family below, every type
    # maps to all of the types listed after it in the "<" chain, and the last type of a chain
    # maps to the empty set.
    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = {
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        exp.DataType.Type.TEXT: set(),
        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
        exp.DataType.Type.NCHAR: {
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        exp.DataType.Type.CHAR: {
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        exp.DataType.Type.DOUBLE: set(),
        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
        exp.DataType.Type.BIGINT: {
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.INT: {
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.SMALLINT: {
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.TINYINT: {
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        exp.DataType.Type.TIMESTAMPLTZ: set(),
        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
        exp.DataType.Type.TIMESTAMP: {
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATETIME: {
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATE: {
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
    }

    # Expression classes for which `annotate` walks the scopes produced by `traverse_scope`
    # before annotating; any other expression is annotated directly.
    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
256
257    def __init__(self, schema=None, annotators=None, coerces_to=None):
258        self.schema = schema
259        self.annotators = annotators or self.ANNOTATORS
260        self.coerces_to = coerces_to or self.COERCES_TO
261
262    def annotate(self, expression):
263        if isinstance(expression, self.TRAVERSABLES):
264            for scope in traverse_scope(expression):
265                selects = {}
266                for name, source in scope.sources.items():
267                    if not isinstance(source, Scope):
268                        continue
269                    if isinstance(source.expression, exp.UDTF):
270                        values = []
271
272                        if isinstance(source.expression, exp.Lateral):
273                            if isinstance(source.expression.this, exp.Explode):
274                                values = [source.expression.this.this]
275                        else:
276                            values = source.expression.expressions[0].expressions
277
278                        if not values:
279                            continue
280
281                        selects[name] = {
282                            alias: column
283                            for alias, column in zip(
284                                source.expression.alias_column_names,
285                                values,
286                            )
287                        }
288                    else:
289                        selects[name] = {
290                            select.alias_or_name: select for select in source.expression.selects
291                        }
292                # First annotate the current scope's column references
293                for col in scope.columns:
294                    if not col.table:
295                        continue
296
297                    source = scope.sources.get(col.table)
298                    if isinstance(source, exp.Table):
299                        col.type = self.schema.get_column_type(source, col)
300                    elif source and col.table in selects and col.name in selects[col.table]:
301                        col.type = selects[col.table][col.name].type
302                # Then (possibly) annotate the remaining expressions in the scope
303                self._maybe_annotate(scope.expression)
304        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
305
306    def _maybe_annotate(self, expression):
307        if expression.type:
308            return expression  # We've already inferred the expression's type
309
310        annotator = self.annotators.get(expression.__class__)
311
312        return (
313            annotator(self, expression)
314            if annotator
315            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
316        )
317
318    def _annotate_args(self, expression):
319        for _, value in expression.iter_expressions():
320            self._maybe_annotate(value)
321
322        return expression
323
324    def _maybe_coerce(self, type1, type2):
325        # We propagate the NULL / UNKNOWN types upwards if found
326        if isinstance(type1, exp.DataType):
327            type1 = type1.this
328        if isinstance(type2, exp.DataType):
329            type2 = type2.this
330
331        if exp.DataType.Type.NULL in (type1, type2):
332            return exp.DataType.Type.NULL
333        if exp.DataType.Type.UNKNOWN in (type1, type2):
334            return exp.DataType.Type.UNKNOWN
335
336        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
337
338    def _annotate_binary(self, expression):
339        self._annotate_args(expression)
340
341        left_type = expression.left.type.this
342        right_type = expression.right.type.this
343
344        if isinstance(expression, exp.Connector):
345            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
346                expression.type = exp.DataType.Type.NULL
347            elif exp.DataType.Type.NULL in (left_type, right_type):
348                expression.type = exp.DataType.build(
349                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
350                )
351            else:
352                expression.type = exp.DataType.Type.BOOLEAN
353        elif isinstance(expression, exp.Predicate):
354            expression.type = exp.DataType.Type.BOOLEAN
355        else:
356            expression.type = self._maybe_coerce(left_type, right_type)
357
358        return expression
359
360    def _annotate_unary(self, expression):
361        self._annotate_args(expression)
362
363        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
364            expression.type = exp.DataType.Type.BOOLEAN
365        else:
366            expression.type = expression.this.type
367
368        return expression
369
370    def _annotate_literal(self, expression):
371        if expression.is_string:
372            expression.type = exp.DataType.Type.VARCHAR
373        elif expression.is_int:
374            expression.type = exp.DataType.Type.INT
375        else:
376            expression.type = exp.DataType.Type.DOUBLE
377
378        return expression
379
380    def _annotate_with_type(self, expression, target_type):
381        expression.type = target_type
382        return self._annotate_args(expression)
383
384    def _annotate_by_args(self, expression, *args, promote=False):
385        self._annotate_args(expression)
386        expressions = []
387        for arg in args:
388            arg_expr = expression.args.get(arg)
389            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
390
391        last_datatype = None
392        for expr in expressions:
393            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
394
395        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
396
397        if promote:
398            if expression.type.this in exp.DataType.INTEGER_TYPES:
399                expression.type = exp.DataType.Type.BIGINT
400            elif expression.type.this in exp.DataType.FLOAT_TYPES:
401                expression.type = exp.DataType.Type.DOUBLE
402
403        return expression
TypeAnnotator(schema=None, annotators=None, coerces_to=None)
257    def __init__(self, schema=None, annotators=None, coerces_to=None):
258        self.schema = schema
259        self.annotators = annotators or self.ANNOTATORS
260        self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
262    def annotate(self, expression):
263        if isinstance(expression, self.TRAVERSABLES):
264            for scope in traverse_scope(expression):
265                selects = {}
266                for name, source in scope.sources.items():
267                    if not isinstance(source, Scope):
268                        continue
269                    if isinstance(source.expression, exp.UDTF):
270                        values = []
271
272                        if isinstance(source.expression, exp.Lateral):
273                            if isinstance(source.expression.this, exp.Explode):
274                                values = [source.expression.this.this]
275                        else:
276                            values = source.expression.expressions[0].expressions
277
278                        if not values:
279                            continue
280
281                        selects[name] = {
282                            alias: column
283                            for alias, column in zip(
284                                source.expression.alias_column_names,
285                                values,
286                            )
287                        }
288                    else:
289                        selects[name] = {
290                            select.alias_or_name: select for select in source.expression.selects
291                        }
292                # First annotate the current scope's column references
293                for col in scope.columns:
294                    if not col.table:
295                        continue
296
297                    source = scope.sources.get(col.table)
298                    if isinstance(source, exp.Table):
299                        col.type = self.schema.get_column_type(source, col)
300                    elif source and col.table in selects and col.name in selects[col.table]:
301                        col.type = selects[col.table][col.name].type
302                # Then (possibly) annotate the remaining expressions in the scope
303                self._maybe_annotate(scope.expression)
304        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions