sqlglot.optimizer.annotate_types
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    B = t.TypeVar("B", bound=exp.Binary)


def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema.
        annotators: Maps an expression class to the function that annotates it. Defaults to
            `TypeAnnotator.ANNOTATORS` when None or empty.
        coerces_to: Maps a data type to the set of data types it can be coerced into. Defaults
            to `TypeAnnotator.COERCES_TO` when None or empty.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    """Returns an annotator that always assigns `data_type` to the expression it receives."""
    # Building the lambda inside this factory binds `data_type` eagerly; a lambda written
    # directly inside the ANNOTATORS comprehension would late-bind to the loop's last value.
    return lambda self, e: self._annotate_with_type(e, data_type)


class _TypeAnnotator(type):
    """Metaclass that populates the class's `COERCES_TO` mapping from type-precedence lists."""

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
        text_precedence = (
            exp.DataType.Type.TEXT,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.CHAR,
        )
        numeric_precedence = (
            exp.DataType.Type.DOUBLE,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.TINYINT,
        )
        timelike_precedence = (
            exp.DataType.Type.TIMESTAMPLTZ,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.DATE,
        )

        # Each type can be coerced into every type listed before it in its precedence tuple,
        # e.g. INT -> {DOUBLE, FLOAT, DECIMAL, BIGINT}.
        # NOTE: a subclass that does not define its own COERCES_TO resolves `klass.COERCES_TO`
        # to its parent's dict, so that shared dict is (re)populated in place.
        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
            coerces_to = set()
            for data_type in type_precedence:
                klass.COERCES_TO[data_type] = coerces_to.copy()
                coerces_to |= {data_type}

        return klass


class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Infers types for the nodes of a sqlglot AST and stores them in `expression.type`.

    Each node is dispatched on its exact class through `self.annotators`; nodes without a
    registered annotator are typed as UNKNOWN.
    """

    # Expression classes whose result type is fixed, regardless of their arguments' types.
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        # exp.ArrayConcat is intentionally absent: ANNOTATORS types it from its arguments, and
        # that explicit entry overrides anything generated from this table.
        exp.DataType.Type.VARCHAR: {
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps each expression class to the function that annotates it. In a dict display later
    # keys override earlier ones, so the explicit entries at the bottom take precedence over
    # the generated unary / binary / fixed-type entries.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that _maybe_coerce returns as-is (including their parameters) instead of coercing.
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled by the metaclass)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of columns that reference tables.
            annotators: Expression class -> annotation function. Falls back to `ANNOTATORS`
                when None or empty.
            coerces_to: Data type -> set of data types it can be coerced into. Falls back to
                `COERCES_TO` when None or empty.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
        """Assigns `target_type` to `expression` and marks it visited so it is not re-annotated."""
        expression.type = target_type
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotates `expression` and its sub-expressions with types, in place.

        For each scope yielded by `traverse_scope`, table-qualified column references are typed
        first: from the schema when they point at a table, or from the matching projection when
        they point at a derived source. Then the rest of the scope's expression is annotated.
        Typing derived-source columns this way presumably relies on `traverse_scope` yielding
        inner scopes before the scopes that select from them — TODO confirm.

        Args:
            expression: The expression to annotate.

        Returns:
            The same expression object, annotated.
        """
        for scope in traverse_scope(expression):
            # Derived source name -> {output column name: expression that produces it}
            selects: t.Dict[str, t.Dict[str, exp.Expression]] = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    elif source.expression.expressions:
                        # Guarded: a UDTF without arguments would otherwise raise IndexError
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """
        Annotates `expression` unless it was already visited.

        The annotator is looked up by the expression's exact class; when none is registered,
        the expression is typed as UNKNOWN (and its children are still annotated).
        """
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotates every direct child expression of `expression` and returns `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Returns the type that results from combining `type1` with `type2`.

        - NULL if either type is NULL, otherwise UNKNOWN if either type is UNKNOWN.
        - A nested type (see NESTED_TYPES) is returned unchanged, checking `type1` first.
        - Otherwise `type2` if `type1` can be coerced into it (per `self.coerces_to`),
          else `type1`. Only the bare `DataType.Type` value is returned in this case.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, set()) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotates a binary expression after annotating its operands.

        - exp.Connector: NULL if both operands are NULL, NULLABLE(BOOLEAN) if exactly one is
          NULL, else BOOLEAN.
        - exp.Predicate: BOOLEAN.
        - Anything else: the operand types combined via `_maybe_coerce`.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """Types conditions (other than parentheses) as BOOLEAN; otherwise inherits the operand's type."""
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Types string literals as VARCHAR, integer literals as INT, and other literals as DOUBLE."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(
        self, expression: E, target_type: exp.DataType | exp.DataType.Type
    ) -> E:
        """Assigns `target_type` to `expression`, then annotates its children."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Types `expression` by combining the types of the sub-expressions stored in `args`.

        Args:
            expression: The expression to annotate.
            *args: Names of argument slots to read; each may hold a single node or a list,
                and empty entries are skipped. With no usable arguments the type is UNKNOWN.
            promote: Widen integer results to BIGINT and float results to DOUBLE.
            array: Wrap the resulting type in ARRAY<...>.

        Returns:
            The annotated expression.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression
def
annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema.
        annotators: Maps an expression class to the function that annotates it. Defaults to
            `TypeAnnotator.ANNOTATORS` when None or empty.
        coerces_to: Maps a data type to the set of data types it can be coerced into. Defaults
            to `TypeAnnotator.COERCES_TO` when None or empty.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression: Expression to annotate.
- schema: Database schema.
- annotators: Maps expression type to corresponding annotation function.
- coerces_to: Maps a data type to the set of data types it can be coerced into.
Returns:
The expression annotated with types.
class
TypeAnnotator:
class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Infers types for the nodes of a sqlglot AST and stores them in `expression.type`.

    Each node is dispatched on its exact class through `self.annotators`; nodes without a
    registered annotator are typed as UNKNOWN.
    """

    # Expression classes whose result type is fixed, regardless of their arguments' types.
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        # exp.ArrayConcat is intentionally absent: ANNOTATORS types it from its arguments, and
        # that explicit entry overrides anything generated from this table.
        exp.DataType.Type.VARCHAR: {
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps each expression class to the function that annotates it. In a dict display later
    # keys override earlier ones, so the explicit entries at the bottom take precedence over
    # the generated unary / binary / fixed-type entries.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that _maybe_coerce returns as-is (including their parameters) instead of coercing.
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled by the metaclass)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of columns that reference tables.
            annotators: Expression class -> annotation function. Falls back to `ANNOTATORS`
                when None or empty.
            coerces_to: Data type -> set of data types it can be coerced into. Falls back to
                `COERCES_TO` when None or empty.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
        """Assigns `target_type` to `expression` and marks it visited so it is not re-annotated."""
        expression.type = target_type
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotates `expression` and its sub-expressions with types, in place.

        For each scope yielded by `traverse_scope`, table-qualified column references are typed
        first: from the schema when they point at a table, or from the matching projection when
        they point at a derived source. Then the rest of the scope's expression is annotated.
        Typing derived-source columns this way presumably relies on `traverse_scope` yielding
        inner scopes before the scopes that select from them — TODO confirm.

        Args:
            expression: The expression to annotate.

        Returns:
            The same expression object, annotated.
        """
        for scope in traverse_scope(expression):
            # Derived source name -> {output column name: expression that produces it}
            selects: t.Dict[str, t.Dict[str, exp.Expression]] = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    elif source.expression.expressions:
                        # Guarded: a UDTF without arguments would otherwise raise IndexError
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """
        Annotates `expression` unless it was already visited.

        The annotator is looked up by the expression's exact class; when none is registered,
        the expression is typed as UNKNOWN (and its children are still annotated).
        """
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotates every direct child expression of `expression` and returns `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Returns the type that results from combining `type1` with `type2`.

        - NULL if either type is NULL, otherwise UNKNOWN if either type is UNKNOWN.
        - A nested type (see NESTED_TYPES) is returned unchanged, checking `type1` first.
        - Otherwise `type2` if `type1` can be coerced into it (per `self.coerces_to`),
          else `type1`. Only the bare `DataType.Type` value is returned in this case.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, set()) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotates a binary expression after annotating its operands.

        - exp.Connector: NULL if both operands are NULL, NULLABLE(BOOLEAN) if exactly one is
          NULL, else BOOLEAN.
        - exp.Predicate: BOOLEAN.
        - Anything else: the operand types combined via `_maybe_coerce`.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """Types conditions (other than parentheses) as BOOLEAN; otherwise inherits the operand's type."""
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Types string literals as VARCHAR, integer literals as INT, and other literals as DOUBLE."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(
        self, expression: E, target_type: exp.DataType | exp.DataType.Type
    ) -> E:
        """Assigns `target_type` to `expression`, then annotates its children."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Types `expression` by combining the types of the sub-expressions stored in `args`.

        Args:
            expression: The expression to annotate.
            *args: Names of argument slots to read; each may hold a single node or a list,
                and empty entries are skipped. With no usable arguments the type is UNKNOWN.
            promote: Widen integer results to BIGINT and float results to DOUBLE.
            array: Wrap the resulting type in ARRAY<...>.

        Returns:
            The annotated expression.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
    """
    Stores the schema and the dispatch/coercion tables used during annotation.

    Args:
        schema: Schema used to look up the types of columns that reference tables.
        annotators: Expression class -> annotation function. A None or empty mapping falls
            back to the class-level `ANNOTATORS`.
        coerces_to: Data type -> set of data types it can be coerced into. A None or empty
            mapping falls back to the class-level `COERCES_TO`.
    """
    self.schema = schema

    # Truthiness test on purpose: an empty override is treated the same as no override.
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO

    # ids of sub-expressions whose type is already set, so each one is annotated only once
    self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] =
{<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ArraySize'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.Boolean'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.DateAdd'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.DateSub'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.DateStrToDate'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeSub'>, <class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.Sqrt'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 
'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.Extract'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.Timestamp'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Week'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.SafeDPipe'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Trim'>}}
ANNOTATORS: Dict =
{<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 
'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.DateSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: 
<function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, 
<class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: 
<function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] =
{<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.CHAR: 'CHAR'>: {<Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.NCHAR: 'NCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.INT: 'INT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.INT: 'INT'>, <Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.TINYINT: 'TINYINT'>: {<Type.INT: 'INT'>, <Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.SMALLINT: 'SMALLINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATE: 'DATE'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}}
def annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """
    Annotate `expression`, and every scope nested inside it, with inferred types.

    Scopes are visited in the order yielded by `traverse_scope`. Within each scope:

    1. The scope's derived sources (child `Scope`s) are collected, mapping each output
       column name to the expression that produces it. For a UDTF source, the output
       names come from `alias_column_names` and are paired positionally with its values:
       - A `Lateral` over an `Explode` contributes the exploded expression.
       - Any other `Lateral` contributes nothing.
       - Every other UDTF contributes the expressions of its first argument.
    2. Each table-qualified column reference is typed:
       - If it points at an `exp.Table`, the type comes from the schema.
       - If it points at a derived source that exposes the column, the type is copied
         from the matching projection.
       Unqualified columns are skipped.
    3. The scope's own expression is then passed to `_maybe_annotate`.

    Args:
        expression: The expression to annotate.

    Returns:
        The result of `_maybe_annotate` on the root expression, which also covers
        non-traversable expressions.
    """
    for scope in traverse_scope(expression):
        # source name -> {output column name: expression producing that column}
        projections_by_source = {}

        for source_name, child_scope in scope.sources.items():
            if not isinstance(child_scope, Scope):
                continue

            child_expression = child_scope.expression

            if not isinstance(child_expression, exp.UDTF):
                projections_by_source[source_name] = {
                    projection.alias_or_name: projection
                    for projection in child_expression.selects
                }
                continue

            if isinstance(child_expression, exp.Lateral):
                lateral_target = child_expression.this
                produced = (
                    [lateral_target.this] if isinstance(lateral_target, exp.Explode) else []
                )
            else:
                produced = child_expression.expressions[0].expressions

            # A UDTF that yields nothing does not register the source at all.
            if produced:
                projections_by_source[source_name] = dict(
                    zip(child_expression.alias_column_names, produced)
                )

        # First annotate the current scope's column references
        for column in scope.columns:
            table_name = column.table
            if not table_name:
                continue

            referenced = scope.sources.get(table_name)
            if isinstance(referenced, exp.Table):
                self._set_type(column, self.schema.get_column_type(referenced, column))
                continue

            if (
                referenced
                and table_name in projections_by_source
                and column.name in projections_by_source[table_name]
            ):
                self._set_type(column, projections_by_source[table_name][column.name].type)

        # Then (possibly) annotate the remaining expressions in the scope
        self._maybe_annotate(scope.expression)

    # This takes care of non-traversable expressions
    return self._maybe_annotate(expression)