sqlglot.optimizer.annotate_types
"""Type inference for expression trees: annotates each node with a data type."""
from __future__ import annotations

import functools
import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import (
    ensure_list,
    is_date_unit,
    is_iso_date,
    is_iso_datetime,
    seq_get,
    subclasses,
)
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema

# These names are only referenced from annotations, which are never evaluated at
# runtime because of `from __future__ import annotations`.
if t.TYPE_CHECKING:
    B = t.TypeVar("B", bound=exp.Binary)

    BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
    BinaryCoercions = t.Dict[
        t.Tuple[exp.DataType.Type, exp.DataType.Type],
        BinaryCoercionFunc,
    ]


def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps a data type to the set of data types that it can be coerced into.

    Returns:
        The expression annotated with types.
    """

    # Normalize the (possibly missing or dict-based) schema before handing it over
    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    """
    Builds an annotator that always assigns `data_type` to the expression it receives.

    Going through a factory binds `data_type` when the annotator is created, so annotators
    produced inside a comprehension each keep their own type.
    """
    return lambda self, e: self._annotate_with_type(e, data_type)


def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
    """
    Picks the type produced when the text of `l` (its `name`) is combined with `unit`.

    Returns DATE for an ISO date paired with a date unit, DATETIME for any other ISO
    date or ISO datetime, and UNKNOWN when the text is neither.
    """
    date_text = l.name
    is_iso_date_ = is_iso_date(date_text)

    if is_iso_date_ and is_date_unit(unit):
        return exp.DataType.Type.DATE

    # An ISO date is also an ISO datetime, but not vice versa
    if is_iso_date_ or is_iso_datetime(date_text):
        return exp.DataType.Type.DATETIME

    return exp.DataType.Type.UNKNOWN


def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
    """
    Picks the type produced when the already-typed expression `l` is combined with `unit`.

    Returns DATETIME when `unit` is not a date unit; otherwise the type already attached
    to `l`, or UNKNOWN if `l` has no type.
    """
    if not is_date_unit(unit):
        return exp.DataType.Type.DATETIME
    return l.type.this if l.type else exp.DataType.Type.UNKNOWN


def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
    """Wraps `func` so that it is called with its two arguments in reversed order."""

    @functools.wraps(func)
    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
        return func(r, l)

    return _swapped


def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
    """
    Returns `coercions` extended with a mirrored entry for every type pair.

    Each `(a, b)` entry gains a `(b, a)` counterpart whose function receives its operands
    swapped. Mirrored entries are merged last, so they win if a key already exists.
    """
    return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}


class _TypeAnnotator(type):
    """Metaclass that fills in the `COERCES_TO` table of each class it creates."""

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
        text_precedence = (
            exp.DataType.Type.TEXT,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.CHAR,
        )
        numeric_precedence = (
            exp.DataType.Type.DOUBLE,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.TINYINT,
        )
        timelike_precedence = (
            exp.DataType.Type.TIMESTAMPLTZ,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.DATE,
        )

        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
            # Accumulates the types seen so far, i.e. those listed before the current one
            coerces_to = set()
            for data_type in type_precedence:
                # Store a snapshot: each type maps to the types that precede it in its tuple
                klass.COERCES_TO[data_type] = coerces_to.copy()
                coerces_to |= {data_type}

        return klass


class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Walks an expression tree and attaches a type to every node it visits.

    Each expression class is dispatched to an annotator function looked up in
    `annotators`; classes without an entry are typed as UNKNOWN.
    """

    # Expression classes whose result type is fixed, grouped by that type
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to the function that annotates it. Later entries of this
    # literal override earlier ones with the same key, so the explicit entries at the
    # bottom take priority (e.g. exp.ArrayConcat is listed under VARCHAR above but is
    # remapped to `_annotate_by_args` here).
    # NOTE(review): `subclasses` presumably yields the classes of the `exp` module derived
    # from the given bases -- verify against sqlglot.helper.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that `_maybe_coerce` returns as-is (full DataType included) instead of coercing
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    # Coercion functions for binary operations.
    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
    # The comprehension variable `t` below lives in the comprehension's own scope, so it
    # does not rebind the module-level `typing` alias `t`.
    BINARY_COERCIONS: BinaryCoercions = {
        **swap_all(
            {
                (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
                    l, r.args.get("unit")
                )
                for t in exp.DataType.TEXT_TYPES
            }
        ),
        **swap_all(
            {
                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
                    l, r.args.get("unit")
                ),
            }
        ),
    }

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
        binary_coercions: t.Optional[BinaryCoercions] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of table columns.
            annotators: Overrides `ANNOTATORS` when truthy.
            coerces_to: Overrides `COERCES_TO` when truthy.
            binary_coercions: Overrides `BINARY_COERCIONS` when truthy.

        Because the fallbacks use `or`, an empty mapping is treated like `None` and the
        class-level default is used instead.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO
        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
    ) -> None:
        """Assigns `target_type` to `expression` and marks the expression as visited."""
        expression.type = target_type  # type: ignore
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotates `expression` scope by scope and returns it.

        For every scope yielded by `traverse_scope`, table-qualified columns are typed
        first (from the schema or from the projections of the source they refer to),
        then the rest of the scope's expression is annotated.
        """
        for scope in traverse_scope(expression):
            # Source name -> {column alias/name -> expression providing that column}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    # NOTE(review): block nesting was reconstructed from a flattened
                    # listing -- confirm this `else` pairs with the Lateral check.
                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    # Pair the declared column aliases with the value expressions by position
                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    # Reuse whatever type is attached to the matching projection
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """
        Annotates `expression` unless it has already been visited.

        Expression classes with no registered annotator are typed as UNKNOWN.
        """
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotates every child yielded by `expression.iter_expressions()`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Returns the type resulting from combining `type1` and `type2`.

        NULL wins over everything, then UNKNOWN. A nested type (see `NESTED_TYPES`) is
        returned unchanged, `type1` taking priority. Otherwise `type2`'s value is
        returned if it is listed in `coerces_to[type1]`, else `type1`'s value.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotates a binary expression from the types of its two operands.

        Connectors become NULL when both sides are NULL, NULLABLE(BOOLEAN) when exactly
        one side is NULL, and BOOLEAN otherwise. Predicates are BOOLEAN. Any other
        expression uses the matching `binary_coercions` entry if there is one, falling
        back to `_maybe_coerce`.
        """
        self._annotate_args(expression)

        left, right = expression.left, expression.right
        left_type, right_type = left.type.this, right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif (left_type, right_type) in self.binary_coercions:
            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """
        Annotates a unary-style expression.

        Conditions other than `exp.Paren` are BOOLEAN; everything else takes the type
        of its `this` operand.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Types a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        """
        Assigns `target_type` to `expression`, then annotates its children.

        Despite the annotation, some annotators in `ANNOTATORS` pass a full
        `exp.DataType` here (e.g. the Cast and DataType entries).
        """
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Derives the type of `expression` from the types of some of its arguments.

        Args:
            expression: Expression to annotate.
            *args: Keys of `expression.args` whose values take part in the inference.
            promote: If set, an integer result is widened to BIGINT and a float result
                to DOUBLE.
            array: If set, the inferred type is wrapped into an ARRAY type.

        Returns:
            The annotated expression. Its type is UNKNOWN if no argument was found.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            # An argument may hold a single expression or a list; falsy entries are skipped
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        # Fold the argument types pairwise into a single resulting type
        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            # The type computed so far becomes the element type of the array
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression

    def _annotate_timeunit(
        self, expression: exp.TimeUnit | exp.DateTrunc
    ) -> exp.TimeUnit | exp.DateTrunc:
        """
        Annotates an expression that combines an operand (`this`) with a time unit.

        Text operands are resolved with `_coerce_date_literal`, temporal operands with
        `_coerce_date`; any other operand type yields UNKNOWN.
        """
        self._annotate_args(expression)

        if expression.this.type.this in exp.DataType.TEXT_TYPES:
            datatype = _coerce_date_literal(expression.this, expression.unit)
        elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
            datatype = _coerce_date(expression.this, expression.unit)
        else:
            datatype = exp.DataType.Type.UNKNOWN

        self._set_type(expression, datatype)
        return expression

    def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
        """
        Annotates a bracket (subscript) expression.

        A slice keeps the type of the subscripted expression, an ARRAY yields its first
        type parameter, and a Map / VarMap literal yields the type of the value paired
        with the given key. Everything else, including missing information, is UNKNOWN.
        """
        self._annotate_args(expression)

        # Only the first bracket argument is considered
        bracket_arg = expression.expressions[0]
        this = expression.this

        if isinstance(bracket_arg, exp.Slice):
            self._set_type(expression, this.type)
        elif this.type.is_type(exp.DataType.Type.ARRAY):
            contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
            self._set_type(expression, contained_type)
        elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
            # Find the value sitting at the same position as the requested key
            index = this.keys.index(bracket_arg)
            value = seq_get(this.values, index)
            value_type = value.type if value else exp.DataType.Type.UNKNOWN
            self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
        else:
            self._set_type(expression, exp.DataType.Type.UNKNOWN)

        return expression
def
annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Annotate the AST of an expression with inferred types.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema; it is normalized with `ensure_schema` before use.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps a data type to the set of data types that it can be coerced into.

    Returns:
        The expression annotated with types.
    """
    # Normalize the schema first, then delegate all of the work to a fresh annotator
    type_annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return type_annotator.annotate(expression)
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression: Expression to annotate.
- schema: Database schema.
- annotators: Maps expression type to corresponding annotation function.
- coerces_to: Maps a data type to the set of data types that it can be coerced into.
Returns:
The expression annotated with types.
def
swap_args( func: Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]) -> Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]:
def
swap_all( coercions: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]) -> Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]:
class
TypeAnnotator:
class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Walks an expression tree and attaches a type to every node it visits.

    Each expression class is dispatched to an annotator function looked up in
    `annotators`; classes without an entry are typed as UNKNOWN.
    """

    # Expression classes whose result type is fixed, grouped by that type
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to the function that annotates it. Later entries of this
    # literal override earlier ones with the same key, so the explicit entries at the
    # bottom take priority (e.g. exp.ArrayConcat is listed under VARCHAR above but is
    # remapped to `_annotate_by_args` here).
    # NOTE(review): `subclasses` presumably yields the classes of the `exp` module derived
    # from the given bases -- verify against sqlglot.helper.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that `_maybe_coerce` returns as-is (full DataType included) instead of coercing
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    # Coercion functions for binary operations.
    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
    # The comprehension variable `t` below lives in the comprehension's own scope, so it
    # does not rebind the module-level `typing` alias `t`.
    BINARY_COERCIONS: BinaryCoercions = {
        **swap_all(
            {
                (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
                    l, r.args.get("unit")
                )
                for t in exp.DataType.TEXT_TYPES
            }
        ),
        **swap_all(
            {
                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
                    l, r.args.get("unit")
                ),
            }
        ),
    }

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
        binary_coercions: t.Optional[BinaryCoercions] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of table columns.
            annotators: Overrides `ANNOTATORS` when truthy.
            coerces_to: Overrides `COERCES_TO` when truthy.
            binary_coercions: Overrides `BINARY_COERCIONS` when truthy.

        Because the fallbacks use `or`, an empty mapping is treated like `None` and the
        class-level default is used instead.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO
        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
    ) -> None:
        """Assigns `target_type` to `expression` and marks the expression as visited."""
        expression.type = target_type  # type: ignore
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotates `expression` scope by scope and returns it.

        For every scope yielded by `traverse_scope`, table-qualified columns are typed
        first (from the schema or from the projections of the source they refer to),
        then the rest of the scope's expression is annotated.
        """
        for scope in traverse_scope(expression):
            # Source name -> {column alias/name -> expression providing that column}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    # NOTE(review): block nesting was reconstructed from a flattened
                    # listing -- confirm this `else` pairs with the Lateral check.
                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    # Pair the declared column aliases with the value expressions by position
                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    # Reuse whatever type is attached to the matching projection
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """
        Annotates `expression` unless it has already been visited.

        Expression classes with no registered annotator are typed as UNKNOWN.
        """
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotates every child yielded by `expression.iter_expressions()`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Returns the type resulting from combining `type1` and `type2`.

        NULL wins over everything, then UNKNOWN. A nested type (see `NESTED_TYPES`) is
        returned unchanged, `type1` taking priority. Otherwise `type2`'s value is
        returned if it is listed in `coerces_to[type1]`, else `type1`'s value.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotates a binary expression from the types of its two operands.

        Connectors become NULL when both sides are NULL, NULLABLE(BOOLEAN) when exactly
        one side is NULL, and BOOLEAN otherwise. Predicates are BOOLEAN. Any other
        expression uses the matching `binary_coercions` entry if there is one, falling
        back to `_maybe_coerce`.
        """
        self._annotate_args(expression)

        left, right = expression.left, expression.right
        left_type, right_type = left.type.this, right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif (left_type, right_type) in self.binary_coercions:
            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """
        Annotates a unary-style expression.

        Conditions other than `exp.Paren` are BOOLEAN; everything else takes the type
        of its `this` operand.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Types a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        """
        Assigns `target_type` to `expression`, then annotates its children.

        Despite the annotation, some annotators in `ANNOTATORS` pass a full
        `exp.DataType` here (e.g. the Cast and DataType entries).
        """
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Derives the type of `expression` from the types of some of its arguments.

        Args:
            expression: Expression to annotate.
            *args: Keys of `expression.args` whose values take part in the inference.
            promote: If set, an integer result is widened to BIGINT and a float result
                to DOUBLE.
            array: If set, the inferred type is wrapped into an ARRAY type.

        Returns:
            The annotated expression. Its type is UNKNOWN if no argument was found.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            # An argument may hold a single expression or a list; falsy entries are skipped
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        # Fold the argument types pairwise into a single resulting type
        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            # The type computed so far becomes the element type of the array
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression

    def _annotate_timeunit(
        self, expression: exp.TimeUnit | exp.DateTrunc
    ) -> exp.TimeUnit | exp.DateTrunc:
        """
        Annotates an expression that combines an operand (`this`) with a time unit.

        Text operands are resolved with `_coerce_date_literal`, temporal operands with
        `_coerce_date`; any other operand type yields UNKNOWN.
        """
        self._annotate_args(expression)

        if expression.this.type.this in exp.DataType.TEXT_TYPES:
            datatype = _coerce_date_literal(expression.this, expression.unit)
        elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
            datatype = _coerce_date(expression.this, expression.unit)
        else:
            datatype = exp.DataType.Type.UNKNOWN

        self._set_type(expression, datatype)
        return expression

    def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
        """
        Annotates a bracket (subscript) expression.

        A slice keeps the type of the subscripted expression, an ARRAY yields its first
        type parameter, and a Map / VarMap literal yields the type of the value paired
        with the given key. Everything else, including missing information, is UNKNOWN.
        """
        self._annotate_args(expression)

        # Only the first bracket argument is considered
        bracket_arg = expression.expressions[0]
        this = expression.this

        if isinstance(bracket_arg, exp.Slice):
            self._set_type(expression, this.type)
        elif this.type.is_type(exp.DataType.Type.ARRAY):
            contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
            self._set_type(expression, contained_type)
        elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
            # Find the value sitting at the same position as the requested key
            index = this.keys.index(bracket_arg)
            value = seq_get(this.values, index)
            value_type = value.type if value else exp.DataType.Type.UNKNOWN
            self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
        else:
            self._set_type(expression, exp.DataType.Type.UNKNOWN)

        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None, binary_coercions: Optional[Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    binary_coercions: t.Optional[BinaryCoercions] = None,
) -> None:
    """
    Sets up an annotator for the given schema.

    Each optional mapping falls back to the corresponding class-level table
    when it is falsy (i.e. `None` or empty).

    Args:
        schema: Schema stored on the instance as `self.schema`.
        annotators: Maps expression type to annotation function; defaults to `ANNOTATORS`.
        coerces_to: Maps a type to the set of types it can be coerced into;
            defaults to `COERCES_TO`.
        binary_coercions: Maps a pair of operand types to a function computing the
            resulting type; defaults to `BINARY_COERCIONS`.
    """
    self.schema = schema

    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
    self.binary_coercions = binary_coercions if binary_coercions else self.BINARY_COERCIONS

    # Ids of sub-expressions that were already annotated, so each is visited only once
    self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] =
{<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ArraySize'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.RegexpLike'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.Date'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Log10'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.Ceil'>, <class 
'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.DateToDi'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.Timestamp'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimeAdd'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Day'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.SafeDPipe'>, <class 'sqlglot.expressions.Trim'>}}
ANNOTATORS: Dict =
{<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 
'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Bracket'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 
'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] =
{<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.NCHAR: 'NCHAR'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.INT: 'INT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.INT: 'INT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.SMALLINT: 'SMALLINT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.INT: 'INT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATE: 'DATE'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}}
BINARY_COERCIONS: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]] =
{(<Type.NVARCHAR: 'NVARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.TEXT: 'TEXT'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.NCHAR: 'NCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.CHAR: 'CHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.VARCHAR: 'VARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NVARCHAR: 'NVARCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.TEXT: 'TEXT'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NCHAR: 'NCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.CHAR: 'CHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.VARCHAR: 'VARCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.DATE: 'DATE'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.DATE: 'DATE'>): <function TypeAnnotator.<lambda>>}
def
annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """
    Annotates `expression` scope by scope.

    For every scope yielded by `traverse_scope`, the table-qualified columns
    are typed first (from the schema for tables, or from the matching
    projection of a nested scope), and then the scope's own expression is
    annotated.

    Args:
        expression: The expression to annotate.

    Returns:
        The result of `_maybe_annotate` on the root expression.
    """
    for scope in traverse_scope(expression):
        # source name -> {output column name -> expression producing that column}
        projections = {}

        for source_name, source in scope.sources.items():
            if not isinstance(source, Scope):
                continue

            source_expr = source.expression

            if not isinstance(source_expr, exp.UDTF):
                projections[source_name] = {
                    projection.alias_or_name: projection for projection in source_expr.selects
                }
                continue

            if isinstance(source_expr, exp.Lateral):
                # Only an exploded operand contributes a value here
                lateral_call = source_expr.this
                values = [lateral_call.this] if isinstance(lateral_call, exp.Explode) else []
            else:
                values = source_expr.expressions[0].expressions

            if values:
                # Pair the declared column aliases with the produced values
                projections[source_name] = dict(zip(source_expr.alias_column_names, values))

        # Column references of the current scope are annotated first
        for col in scope.columns:
            if not col.table:
                continue

            origin = scope.sources.get(col.table)
            if isinstance(origin, exp.Table):
                self._set_type(col, self.schema.get_column_type(origin, col))
            elif origin and col.table in projections and col.name in projections[col.table]:
                self._set_type(col, projections[col.table][col.name].type)

        # The remaining expressions in the scope are (possibly) annotated afterwards
        self._maybe_annotate(scope.expression)

    # Covers expressions that scope traversal does not reach
    return self._maybe_annotate(expression)