sqlglot.optimizer.annotate_types
1from __future__ import annotations 2 3import functools 4import typing as t 5 6from sqlglot import exp 7from sqlglot._typing import E 8from sqlglot.helper import ( 9 ensure_list, 10 is_date_unit, 11 is_iso_date, 12 is_iso_datetime, 13 seq_get, 14 subclasses, 15) 16from sqlglot.optimizer.scope import Scope, traverse_scope 17from sqlglot.schema import Schema, ensure_schema 18 19if t.TYPE_CHECKING: 20 B = t.TypeVar("B", bound=exp.Binary) 21 22 BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] 23 BinaryCoercions = t.Dict[ 24 t.Tuple[exp.DataType.Type, exp.DataType.Type], 25 BinaryCoercionFunc, 26 ] 27 28 29def annotate_types( 30 expression: E, 31 schema: t.Optional[t.Dict | Schema] = None, 32 annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, 33 coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, 34) -> E: 35 """ 36 Infers the types of an expression, annotating its AST accordingly. 37 38 Example: 39 >>> import sqlglot 40 >>> schema = {"y": {"cola": "SMALLINT"}} 41 >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" 42 >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) 43 >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" 44 <Type.DOUBLE: 'DOUBLE'> 45 46 Args: 47 expression: Expression to annotate. 48 schema: Database schema. 49 annotators: Maps expression type to corresponding annotation function. 50 coerces_to: Maps expression type to set of types that it can be coerced into. 51 52 Returns: 53 The expression annotated with types. 
54 """ 55 56 schema = ensure_schema(schema) 57 58 return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) 59 60 61def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: 62 return lambda self, e: self._annotate_with_type(e, data_type) 63 64 65def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: 66 date_text = l.name 67 is_iso_date_ = is_iso_date(date_text) 68 69 if is_iso_date_ and is_date_unit(unit): 70 return exp.DataType.Type.DATE 71 72 # An ISO date is also an ISO datetime, but not vice versa 73 if is_iso_date_ or is_iso_datetime(date_text): 74 return exp.DataType.Type.DATETIME 75 76 return exp.DataType.Type.UNKNOWN 77 78 79def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: 80 if not is_date_unit(unit): 81 return exp.DataType.Type.DATETIME 82 return l.type.this if l.type else exp.DataType.Type.UNKNOWN 83 84 85def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc: 86 @functools.wraps(func) 87 def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: 88 return func(r, l) 89 90 return _swapped 91 92 93def swap_all(coercions: BinaryCoercions) -> BinaryCoercions: 94 return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}} 95 96 97class _TypeAnnotator(type): 98 def __new__(cls, clsname, bases, attrs): 99 klass = super().__new__(cls, clsname, bases, attrs) 100 101 # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI): 102 # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html 103 text_precedence = ( 104 exp.DataType.Type.TEXT, 105 exp.DataType.Type.NVARCHAR, 106 exp.DataType.Type.VARCHAR, 107 exp.DataType.Type.NCHAR, 108 exp.DataType.Type.CHAR, 109 ) 110 numeric_precedence = ( 111 exp.DataType.Type.DOUBLE, 112 exp.DataType.Type.FLOAT, 113 exp.DataType.Type.DECIMAL, 114 exp.DataType.Type.BIGINT, 115 exp.DataType.Type.INT, 116 
exp.DataType.Type.SMALLINT, 117 exp.DataType.Type.TINYINT, 118 ) 119 timelike_precedence = ( 120 exp.DataType.Type.TIMESTAMPLTZ, 121 exp.DataType.Type.TIMESTAMPTZ, 122 exp.DataType.Type.TIMESTAMP, 123 exp.DataType.Type.DATETIME, 124 exp.DataType.Type.DATE, 125 ) 126 127 for type_precedence in (text_precedence, numeric_precedence, timelike_precedence): 128 coerces_to = set() 129 for data_type in type_precedence: 130 klass.COERCES_TO[data_type] = coerces_to.copy() 131 coerces_to |= {data_type} 132 133 return klass 134 135 136class TypeAnnotator(metaclass=_TypeAnnotator): 137 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 138 exp.DataType.Type.BIGINT: { 139 exp.ApproxDistinct, 140 exp.ArraySize, 141 exp.Count, 142 exp.Length, 143 }, 144 exp.DataType.Type.BOOLEAN: { 145 exp.Between, 146 exp.Boolean, 147 exp.In, 148 exp.RegexpLike, 149 }, 150 exp.DataType.Type.DATE: { 151 exp.CurrentDate, 152 exp.Date, 153 exp.DateFromParts, 154 exp.DateStrToDate, 155 exp.DiToDate, 156 exp.StrToDate, 157 exp.TimeStrToDate, 158 exp.TsOrDsToDate, 159 }, 160 exp.DataType.Type.DATETIME: { 161 exp.CurrentDatetime, 162 exp.DatetimeAdd, 163 exp.DatetimeSub, 164 }, 165 exp.DataType.Type.DOUBLE: { 166 exp.ApproxQuantile, 167 exp.Avg, 168 exp.Div, 169 exp.Exp, 170 exp.Ln, 171 exp.Log, 172 exp.Log2, 173 exp.Log10, 174 exp.Pow, 175 exp.Quantile, 176 exp.Round, 177 exp.SafeDivide, 178 exp.Sqrt, 179 exp.Stddev, 180 exp.StddevPop, 181 exp.StddevSamp, 182 exp.Variance, 183 exp.VariancePop, 184 }, 185 exp.DataType.Type.INT: { 186 exp.Ceil, 187 exp.DatetimeDiff, 188 exp.DateDiff, 189 exp.Extract, 190 exp.TimestampDiff, 191 exp.TimeDiff, 192 exp.DateToDi, 193 exp.Floor, 194 exp.Levenshtein, 195 exp.StrPosition, 196 exp.TsOrDiToDi, 197 }, 198 exp.DataType.Type.TIMESTAMP: { 199 exp.CurrentTime, 200 exp.CurrentTimestamp, 201 exp.StrToTime, 202 exp.TimeAdd, 203 exp.TimeStrToTime, 204 exp.TimeSub, 205 exp.Timestamp, 206 exp.TimestampAdd, 207 exp.TimestampSub, 208 
exp.UnixToTime, 209 }, 210 exp.DataType.Type.TINYINT: { 211 exp.Day, 212 exp.Month, 213 exp.Week, 214 exp.Year, 215 }, 216 exp.DataType.Type.VARCHAR: { 217 exp.ArrayConcat, 218 exp.Concat, 219 exp.ConcatWs, 220 exp.DateToDateStr, 221 exp.GroupConcat, 222 exp.Initcap, 223 exp.Lower, 224 exp.SafeConcat, 225 exp.SafeDPipe, 226 exp.Substring, 227 exp.TimeToStr, 228 exp.TimeToTimeStr, 229 exp.Trim, 230 exp.TsOrDsToDateStr, 231 exp.UnixToStr, 232 exp.UnixToTimeStr, 233 exp.Upper, 234 }, 235 } 236 237 ANNOTATORS: t.Dict = { 238 **{ 239 expr_type: lambda self, e: self._annotate_unary(e) 240 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 241 }, 242 **{ 243 expr_type: lambda self, e: self._annotate_binary(e) 244 for expr_type in subclasses(exp.__name__, exp.Binary) 245 }, 246 **{ 247 expr_type: _annotate_with_type_lambda(data_type) 248 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 249 for expr_type in expressions 250 }, 251 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 252 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 253 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 254 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 255 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 256 exp.Bracket: lambda self, e: self._annotate_bracket(e), 257 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 258 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 259 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 260 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 261 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 262 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 263 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 264 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 265 
exp.Div: lambda self, e: self._annotate_div(e), 266 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 267 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 268 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 269 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 270 exp.Literal: lambda self, e: self._annotate_literal(e), 271 exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), 272 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 273 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 274 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 275 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 276 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 277 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 278 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 279 exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), 280 } 281 282 NESTED_TYPES = { 283 exp.DataType.Type.ARRAY, 284 } 285 286 # Specifies what types a given type can be coerced into (autofilled) 287 COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} 288 289 # Coercion functions for binary operations. 290 # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type. 
291 BINARY_COERCIONS: BinaryCoercions = { 292 **swap_all( 293 { 294 (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal( 295 l, r.args.get("unit") 296 ) 297 for t in exp.DataType.TEXT_TYPES 298 } 299 ), 300 **swap_all( 301 { 302 (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date( 303 l, r.args.get("unit") 304 ), 305 } 306 ), 307 } 308 309 def __init__( 310 self, 311 schema: Schema, 312 annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, 313 coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, 314 binary_coercions: t.Optional[BinaryCoercions] = None, 315 ) -> None: 316 self.schema = schema 317 self.annotators = annotators or self.ANNOTATORS 318 self.coerces_to = coerces_to or self.COERCES_TO 319 self.binary_coercions = binary_coercions or self.BINARY_COERCIONS 320 321 # Caches the ids of annotated sub-Expressions, to ensure we only visit them once 322 self._visited: t.Set[int] = set() 323 324 def _set_type( 325 self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type 326 ) -> None: 327 expression.type = target_type # type: ignore 328 self._visited.add(id(expression)) 329 330 def annotate(self, expression: E) -> E: 331 for scope in traverse_scope(expression): 332 selects = {} 333 for name, source in scope.sources.items(): 334 if not isinstance(source, Scope): 335 continue 336 if isinstance(source.expression, exp.UDTF): 337 values = [] 338 339 if isinstance(source.expression, exp.Lateral): 340 if isinstance(source.expression.this, exp.Explode): 341 values = [source.expression.this.this] 342 else: 343 values = source.expression.expressions[0].expressions 344 345 if not values: 346 continue 347 348 selects[name] = { 349 alias: column 350 for alias, column in zip( 351 source.expression.alias_column_names, 352 values, 353 ) 354 } 355 else: 356 selects[name] = { 357 select.alias_or_name: select for select in source.expression.selects 358 } 359 
360 # First annotate the current scope's column references 361 for col in scope.columns: 362 if not col.table: 363 continue 364 365 source = scope.sources.get(col.table) 366 if isinstance(source, exp.Table): 367 self._set_type(col, self.schema.get_column_type(source, col)) 368 elif source and col.table in selects and col.name in selects[col.table]: 369 self._set_type(col, selects[col.table][col.name].type) 370 371 # Then (possibly) annotate the remaining expressions in the scope 372 self._maybe_annotate(scope.expression) 373 374 return self._maybe_annotate(expression) # This takes care of non-traversable expressions 375 376 def _maybe_annotate(self, expression: E) -> E: 377 if id(expression) in self._visited: 378 return expression # We've already inferred the expression's type 379 380 annotator = self.annotators.get(expression.__class__) 381 382 return ( 383 annotator(self, expression) 384 if annotator 385 else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) 386 ) 387 388 def _annotate_args(self, expression: E) -> E: 389 for _, value in expression.iter_expressions(): 390 self._maybe_annotate(value) 391 392 return expression 393 394 def _maybe_coerce( 395 self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type 396 ) -> exp.DataType | exp.DataType.Type: 397 type1_value = type1.this if isinstance(type1, exp.DataType) else type1 398 type2_value = type2.this if isinstance(type2, exp.DataType) else type2 399 400 # We propagate the NULL / UNKNOWN types upwards if found 401 if exp.DataType.Type.NULL in (type1_value, type2_value): 402 return exp.DataType.Type.NULL 403 if exp.DataType.Type.UNKNOWN in (type1_value, type2_value): 404 return exp.DataType.Type.UNKNOWN 405 406 if type1_value in self.NESTED_TYPES: 407 return type1 408 if type2_value in self.NESTED_TYPES: 409 return type2 410 411 return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore 412 413 # Note: the following 
"no_type_check" decorators were added because mypy was yelling due 414 # to assigning Type values to expression.type (since its getter returns Optional[DataType]). 415 # This is a known mypy issue: https://github.com/python/mypy/issues/3004 416 417 @t.no_type_check 418 def _annotate_binary(self, expression: B) -> B: 419 self._annotate_args(expression) 420 421 left, right = expression.left, expression.right 422 left_type, right_type = left.type.this, right.type.this 423 424 if isinstance(expression, exp.Connector): 425 if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: 426 self._set_type(expression, exp.DataType.Type.NULL) 427 elif exp.DataType.Type.NULL in (left_type, right_type): 428 self._set_type( 429 expression, 430 exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")), 431 ) 432 else: 433 self._set_type(expression, exp.DataType.Type.BOOLEAN) 434 elif isinstance(expression, exp.Predicate): 435 self._set_type(expression, exp.DataType.Type.BOOLEAN) 436 elif (left_type, right_type) in self.binary_coercions: 437 self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right)) 438 else: 439 self._set_type(expression, self._maybe_coerce(left_type, right_type)) 440 441 return expression 442 443 @t.no_type_check 444 def _annotate_unary(self, expression: E) -> E: 445 self._annotate_args(expression) 446 447 if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): 448 self._set_type(expression, exp.DataType.Type.BOOLEAN) 449 else: 450 self._set_type(expression, expression.this.type) 451 452 return expression 453 454 @t.no_type_check 455 def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: 456 if expression.is_string: 457 self._set_type(expression, exp.DataType.Type.VARCHAR) 458 elif expression.is_int: 459 self._set_type(expression, exp.DataType.Type.INT) 460 else: 461 self._set_type(expression, exp.DataType.Type.DOUBLE) 462 463 return expression 464 465 
@t.no_type_check 466 def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E: 467 self._set_type(expression, target_type) 468 return self._annotate_args(expression) 469 470 @t.no_type_check 471 def _annotate_by_args( 472 self, expression: E, *args: str, promote: bool = False, array: bool = False 473 ) -> E: 474 self._annotate_args(expression) 475 476 expressions: t.List[exp.Expression] = [] 477 for arg in args: 478 arg_expr = expression.args.get(arg) 479 expressions.extend(expr for expr in ensure_list(arg_expr) if expr) 480 481 last_datatype = None 482 for expr in expressions: 483 last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) 484 485 self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN) 486 487 if promote: 488 if expression.type.this in exp.DataType.INTEGER_TYPES: 489 self._set_type(expression, exp.DataType.Type.BIGINT) 490 elif expression.type.this in exp.DataType.FLOAT_TYPES: 491 self._set_type(expression, exp.DataType.Type.DOUBLE) 492 493 if array: 494 self._set_type( 495 expression, 496 exp.DataType( 497 this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True 498 ), 499 ) 500 501 return expression 502 503 def _annotate_timeunit( 504 self, expression: exp.TimeUnit | exp.DateTrunc 505 ) -> exp.TimeUnit | exp.DateTrunc: 506 self._annotate_args(expression) 507 508 if expression.this.type.this in exp.DataType.TEXT_TYPES: 509 datatype = _coerce_date_literal(expression.this, expression.unit) 510 elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES: 511 datatype = _coerce_date(expression.this, expression.unit) 512 else: 513 datatype = exp.DataType.Type.UNKNOWN 514 515 self._set_type(expression, datatype) 516 return expression 517 518 def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket: 519 self._annotate_args(expression) 520 521 bracket_arg = expression.expressions[0] 522 this = expression.this 523 524 if isinstance(bracket_arg, exp.Slice): 525 
self._set_type(expression, this.type) 526 elif this.type.is_type(exp.DataType.Type.ARRAY): 527 contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN 528 self._set_type(expression, contained_type) 529 elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys: 530 index = this.keys.index(bracket_arg) 531 value = seq_get(this.values, index) 532 value_type = value.type if value else exp.DataType.Type.UNKNOWN 533 self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN) 534 else: 535 self._set_type(expression, exp.DataType.Type.UNKNOWN) 536 537 return expression 538 539 def _annotate_div(self, expression: exp.Div) -> exp.Div: 540 self._annotate_args(expression) 541 542 left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore 543 544 if ( 545 expression.args.get("typed") 546 and left_type in exp.DataType.INTEGER_TYPES 547 and right_type in exp.DataType.INTEGER_TYPES 548 ): 549 self._set_type(expression, exp.DataType.Type.BIGINT) 550 else: 551 self._set_type(expression, self._maybe_coerce(left_type, right_type)) 552 553 return expression
def
annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema. A plain dict is converted with ``ensure_schema``.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps each data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """

    # Accept either a Schema instance or a plain mapping (or None).
    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression: Expression to annotate.
- schema: Database schema.
- annotators: Maps expression type to corresponding annotation function.
- coerces_to: Maps each data type to the set of data types it can be coerced into.
Returns:
The expression annotated with types.
def
swap_args( func: Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]) -> Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]:
def
swap_all( coercions: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]) -> Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]:
class
TypeAnnotator:
137class TypeAnnotator(metaclass=_TypeAnnotator): 138 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 139 exp.DataType.Type.BIGINT: { 140 exp.ApproxDistinct, 141 exp.ArraySize, 142 exp.Count, 143 exp.Length, 144 }, 145 exp.DataType.Type.BOOLEAN: { 146 exp.Between, 147 exp.Boolean, 148 exp.In, 149 exp.RegexpLike, 150 }, 151 exp.DataType.Type.DATE: { 152 exp.CurrentDate, 153 exp.Date, 154 exp.DateFromParts, 155 exp.DateStrToDate, 156 exp.DiToDate, 157 exp.StrToDate, 158 exp.TimeStrToDate, 159 exp.TsOrDsToDate, 160 }, 161 exp.DataType.Type.DATETIME: { 162 exp.CurrentDatetime, 163 exp.DatetimeAdd, 164 exp.DatetimeSub, 165 }, 166 exp.DataType.Type.DOUBLE: { 167 exp.ApproxQuantile, 168 exp.Avg, 169 exp.Div, 170 exp.Exp, 171 exp.Ln, 172 exp.Log, 173 exp.Log2, 174 exp.Log10, 175 exp.Pow, 176 exp.Quantile, 177 exp.Round, 178 exp.SafeDivide, 179 exp.Sqrt, 180 exp.Stddev, 181 exp.StddevPop, 182 exp.StddevSamp, 183 exp.Variance, 184 exp.VariancePop, 185 }, 186 exp.DataType.Type.INT: { 187 exp.Ceil, 188 exp.DatetimeDiff, 189 exp.DateDiff, 190 exp.Extract, 191 exp.TimestampDiff, 192 exp.TimeDiff, 193 exp.DateToDi, 194 exp.Floor, 195 exp.Levenshtein, 196 exp.StrPosition, 197 exp.TsOrDiToDi, 198 }, 199 exp.DataType.Type.TIMESTAMP: { 200 exp.CurrentTime, 201 exp.CurrentTimestamp, 202 exp.StrToTime, 203 exp.TimeAdd, 204 exp.TimeStrToTime, 205 exp.TimeSub, 206 exp.Timestamp, 207 exp.TimestampAdd, 208 exp.TimestampSub, 209 exp.UnixToTime, 210 }, 211 exp.DataType.Type.TINYINT: { 212 exp.Day, 213 exp.Month, 214 exp.Week, 215 exp.Year, 216 }, 217 exp.DataType.Type.VARCHAR: { 218 exp.ArrayConcat, 219 exp.Concat, 220 exp.ConcatWs, 221 exp.DateToDateStr, 222 exp.GroupConcat, 223 exp.Initcap, 224 exp.Lower, 225 exp.SafeConcat, 226 exp.SafeDPipe, 227 exp.Substring, 228 exp.TimeToStr, 229 exp.TimeToTimeStr, 230 exp.Trim, 231 exp.TsOrDsToDateStr, 232 exp.UnixToStr, 233 exp.UnixToTimeStr, 234 exp.Upper, 235 }, 236 } 237 238 ANNOTATORS: t.Dict = { 239 **{ 240 
expr_type: lambda self, e: self._annotate_unary(e) 241 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 242 }, 243 **{ 244 expr_type: lambda self, e: self._annotate_binary(e) 245 for expr_type in subclasses(exp.__name__, exp.Binary) 246 }, 247 **{ 248 expr_type: _annotate_with_type_lambda(data_type) 249 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 250 for expr_type in expressions 251 }, 252 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 253 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 254 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 255 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 256 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 257 exp.Bracket: lambda self, e: self._annotate_bracket(e), 258 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 259 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 260 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 261 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 262 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 263 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 264 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 265 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 266 exp.Div: lambda self, e: self._annotate_div(e), 267 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 268 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 269 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 270 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 271 exp.Literal: lambda self, e: self._annotate_literal(e), 272 exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), 273 exp.Max: lambda self, e: 
self._annotate_by_args(e, "this", "expressions"), 274 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 275 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 276 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 277 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 278 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 279 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 280 exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), 281 } 282 283 NESTED_TYPES = { 284 exp.DataType.Type.ARRAY, 285 } 286 287 # Specifies what types a given type can be coerced into (autofilled) 288 COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} 289 290 # Coercion functions for binary operations. 291 # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type. 
292 BINARY_COERCIONS: BinaryCoercions = { 293 **swap_all( 294 { 295 (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal( 296 l, r.args.get("unit") 297 ) 298 for t in exp.DataType.TEXT_TYPES 299 } 300 ), 301 **swap_all( 302 { 303 (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date( 304 l, r.args.get("unit") 305 ), 306 } 307 ), 308 } 309 310 def __init__( 311 self, 312 schema: Schema, 313 annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, 314 coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, 315 binary_coercions: t.Optional[BinaryCoercions] = None, 316 ) -> None: 317 self.schema = schema 318 self.annotators = annotators or self.ANNOTATORS 319 self.coerces_to = coerces_to or self.COERCES_TO 320 self.binary_coercions = binary_coercions or self.BINARY_COERCIONS 321 322 # Caches the ids of annotated sub-Expressions, to ensure we only visit them once 323 self._visited: t.Set[int] = set() 324 325 def _set_type( 326 self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type 327 ) -> None: 328 expression.type = target_type # type: ignore 329 self._visited.add(id(expression)) 330 331 def annotate(self, expression: E) -> E: 332 for scope in traverse_scope(expression): 333 selects = {} 334 for name, source in scope.sources.items(): 335 if not isinstance(source, Scope): 336 continue 337 if isinstance(source.expression, exp.UDTF): 338 values = [] 339 340 if isinstance(source.expression, exp.Lateral): 341 if isinstance(source.expression.this, exp.Explode): 342 values = [source.expression.this.this] 343 else: 344 values = source.expression.expressions[0].expressions 345 346 if not values: 347 continue 348 349 selects[name] = { 350 alias: column 351 for alias, column in zip( 352 source.expression.alias_column_names, 353 values, 354 ) 355 } 356 else: 357 selects[name] = { 358 select.alias_or_name: select for select in source.expression.selects 359 } 360 
361 # First annotate the current scope's column references 362 for col in scope.columns: 363 if not col.table: 364 continue 365 366 source = scope.sources.get(col.table) 367 if isinstance(source, exp.Table): 368 self._set_type(col, self.schema.get_column_type(source, col)) 369 elif source and col.table in selects and col.name in selects[col.table]: 370 self._set_type(col, selects[col.table][col.name].type) 371 372 # Then (possibly) annotate the remaining expressions in the scope 373 self._maybe_annotate(scope.expression) 374 375 return self._maybe_annotate(expression) # This takes care of non-traversable expressions 376 377 def _maybe_annotate(self, expression: E) -> E: 378 if id(expression) in self._visited: 379 return expression # We've already inferred the expression's type 380 381 annotator = self.annotators.get(expression.__class__) 382 383 return ( 384 annotator(self, expression) 385 if annotator 386 else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) 387 ) 388 389 def _annotate_args(self, expression: E) -> E: 390 for _, value in expression.iter_expressions(): 391 self._maybe_annotate(value) 392 393 return expression 394 395 def _maybe_coerce( 396 self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type 397 ) -> exp.DataType | exp.DataType.Type: 398 type1_value = type1.this if isinstance(type1, exp.DataType) else type1 399 type2_value = type2.this if isinstance(type2, exp.DataType) else type2 400 401 # We propagate the NULL / UNKNOWN types upwards if found 402 if exp.DataType.Type.NULL in (type1_value, type2_value): 403 return exp.DataType.Type.NULL 404 if exp.DataType.Type.UNKNOWN in (type1_value, type2_value): 405 return exp.DataType.Type.UNKNOWN 406 407 if type1_value in self.NESTED_TYPES: 408 return type1 409 if type2_value in self.NESTED_TYPES: 410 return type2 411 412 return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore 413 414 # Note: the following 
"no_type_check" decorators were added because mypy was yelling due 415 # to assigning Type values to expression.type (since its getter returns Optional[DataType]). 416 # This is a known mypy issue: https://github.com/python/mypy/issues/3004 417 418 @t.no_type_check 419 def _annotate_binary(self, expression: B) -> B: 420 self._annotate_args(expression) 421 422 left, right = expression.left, expression.right 423 left_type, right_type = left.type.this, right.type.this 424 425 if isinstance(expression, exp.Connector): 426 if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: 427 self._set_type(expression, exp.DataType.Type.NULL) 428 elif exp.DataType.Type.NULL in (left_type, right_type): 429 self._set_type( 430 expression, 431 exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")), 432 ) 433 else: 434 self._set_type(expression, exp.DataType.Type.BOOLEAN) 435 elif isinstance(expression, exp.Predicate): 436 self._set_type(expression, exp.DataType.Type.BOOLEAN) 437 elif (left_type, right_type) in self.binary_coercions: 438 self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right)) 439 else: 440 self._set_type(expression, self._maybe_coerce(left_type, right_type)) 441 442 return expression 443 444 @t.no_type_check 445 def _annotate_unary(self, expression: E) -> E: 446 self._annotate_args(expression) 447 448 if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): 449 self._set_type(expression, exp.DataType.Type.BOOLEAN) 450 else: 451 self._set_type(expression, expression.this.type) 452 453 return expression 454 455 @t.no_type_check 456 def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: 457 if expression.is_string: 458 self._set_type(expression, exp.DataType.Type.VARCHAR) 459 elif expression.is_int: 460 self._set_type(expression, exp.DataType.Type.INT) 461 else: 462 self._set_type(expression, exp.DataType.Type.DOUBLE) 463 464 return expression 465 466 
@t.no_type_check
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
    """
    Give ``expression`` the fixed type ``target_type``, then annotate its children.

    Args:
        expression: The expression to annotate.
        target_type: The type assigned to ``expression``.

    Returns:
        The value returned by ``_annotate_args`` for ``expression``.
    """
    self._set_type(expression, target_type)
    annotated = self._annotate_args(expression)
    return annotated


@t.no_type_check
def _annotate_by_args(
    self, expression: E, *args: str, promote: bool = False, array: bool = False
) -> E:
    """
    Derive the type of ``expression`` from the types of some of its arguments.

    Each named argument value is normalised with ``ensure_list`` and falsy entries
    are dropped. The remaining operand types are folded left to right with
    ``_maybe_coerce``. When there are no operands, the result is UNKNOWN.

    Args:
        expression: The expression to annotate; its children are annotated first.
        *args: Names of the arguments whose types determine the result.
        promote: If set, integer results become BIGINT and float results become DOUBLE.
        array: If set, the final type is wrapped as ARRAY<type>.

    Returns:
        ``expression``, annotated.
    """
    self._annotate_args(expression)

    operands = [
        operand
        for arg_name in args
        for operand in ensure_list(expression.args.get(arg_name))
        if operand
    ]

    result_type = None
    for operand in operands:
        result_type = self._maybe_coerce(result_type or operand.type, operand.type)

    self._set_type(expression, result_type or exp.DataType.Type.UNKNOWN)

    if promote:
        current = expression.type.this
        if current in exp.DataType.INTEGER_TYPES:
            self._set_type(expression, exp.DataType.Type.BIGINT)
        elif current in exp.DataType.FLOAT_TYPES:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

    if array:
        wrapped = exp.DataType(
            this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
        )
        self._set_type(expression, wrapped)

    return expression


def _annotate_timeunit(
    self, expression: exp.TimeUnit | exp.DateTrunc
) -> exp.TimeUnit | exp.DateTrunc:
    """
    Type a time-unit expression from the type of its ``this`` operand.

    - A text operand is typed via ``_coerce_date_literal``.
    - A temporal operand is typed via ``_coerce_date``.
    - Any other operand yields UNKNOWN.

    In each case the expression's ``unit`` is also passed along.
    """
    self._annotate_args(expression)

    operand = expression.this
    operand_type = operand.type.this

    if operand_type in exp.DataType.TEXT_TYPES:
        datatype = _coerce_date_literal(operand, expression.unit)
    elif operand_type in exp.DataType.TEMPORAL_TYPES:
        datatype = _coerce_date(operand, expression.unit)
    else:
        datatype = exp.DataType.Type.UNKNOWN

    self._set_type(expression, datatype)
    return expression


def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
    """
    Type a subscript expression ``this[...]``.

    Only the first bracket argument is inspected:

    - A slice keeps the container's type.
    - Indexing an ARRAY yields its element type, or UNKNOWN if none is recorded.
    - Looking up a key present in a literal MAP / VAR_MAP yields the type of the
      matching value.
    - Anything else yields UNKNOWN.
    """
    self._annotate_args(expression)

    key = expression.expressions[0]
    container = expression.this
    unknown = exp.DataType.Type.UNKNOWN

    if isinstance(key, exp.Slice):
        result = container.type
    elif container.type.is_type(exp.DataType.Type.ARRAY):
        result = seq_get(container.type.expressions, 0) or unknown
    elif isinstance(container, (exp.Map, exp.VarMap)) and key in container.keys:
        value = seq_get(container.values, container.keys.index(key))
        result = (value.type if value else unknown) or unknown
    else:
        result = unknown

    self._set_type(expression, result)
    return expression


def _annotate_div(self, expression: exp.Div) -> exp.Div:
    """
    Type a division.

    When the node carries the ``typed`` arg and both operands have integer
    types, the result is BIGINT. Otherwise the operand types are combined
    with ``_maybe_coerce``.
    """
    self._annotate_args(expression)

    lhs = expression.left.type.this  # type: ignore
    rhs = expression.right.type.this  # type: ignore
    integers = exp.DataType.INTEGER_TYPES
    is_typed = expression.args.get("typed")

    if is_typed and lhs in integers and rhs in integers:
        result = exp.DataType.Type.BIGINT
    else:
        result = self._maybe_coerce(lhs, rhs)

    self._set_type(expression, result)
    return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None, binary_coercions: Optional[Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    binary_coercions: t.Optional[BinaryCoercions] = None,
) -> None:
    """
    Configure the annotator.

    Args:
        schema: Schema consulted (via ``get_column_type``) for columns of table sources.
        annotators: Expression class -> annotation function. Falls back to the
            class-level ``ANNOTATORS`` when omitted or empty.
        coerces_to: Type -> set of types it may be coerced into. Falls back to
            ``COERCES_TO`` when omitted or empty.
        binary_coercions: (left type, right type) -> coercion function. Falls back
            to ``BINARY_COERCIONS`` when omitted or empty.
    """
    self.schema = schema

    # Any falsy argument (None or an empty mapping) selects the class-level default.
    self.annotators = self.ANNOTATORS if not annotators else annotators
    self.coerces_to = self.COERCES_TO if not coerces_to else coerces_to
    self.binary_coercions = (
        self.BINARY_COERCIONS if not binary_coercions else binary_coercions
    )

    # ids of sub-expressions that have already been annotated, so each is visited only once
    self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] =
{<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Length'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.In'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.DateFromParts'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeSub'>, <class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Div'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.SafeDivide'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.DatetimeDiff'>, 
<class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.Levenshtein'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.Timestamp'>, <class 'sqlglot.expressions.UnixToTime'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Month'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.SafeDPipe'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.Upper'>}}
ANNOTATORS: Dict =
{<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.EQ'>: 
<function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.PropertyEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.RegexpILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: 
<function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, 
<class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Abs'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Bracket'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 
'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Nullif'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] =
{<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.CHAR: 'CHAR'>: {<Type.NCHAR: 'NCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.INT: 'INT'>: {<Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.INT: 'INT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.SMALLINT: 'SMALLINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.INT: 'INT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}, <Type.DATE: 'DATE'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}}
BINARY_COERCIONS: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]] =
{(<Type.VARCHAR: 'VARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.CHAR: 'CHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.TEXT: 'TEXT'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.NCHAR: 'NCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.NVARCHAR: 'NVARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.VARCHAR: 'VARCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.CHAR: 'CHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.TEXT: 'TEXT'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NCHAR: 'NCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NVARCHAR: 'NVARCHAR'>): <function TypeAnnotator.<dictcomp>.<lambda>>, (<Type.DATE: 'DATE'>, <Type.INTERVAL: 'INTERVAL'>): <function TypeAnnotator.<lambda>>, (<Type.INTERVAL: 'INTERVAL'>, <Type.DATE: 'DATE'>): <function TypeAnnotator.<lambda>>}
def
annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """
    Infer and attach types to ``expression`` and its sub-expressions.

    Every scope yielded by ``traverse_scope`` is processed in turn:

    1. For each source of the scope that is itself a ``Scope``, build a map from
       output column name to the expression producing that column.
    2. Type the scope's table-qualified column references. Columns of table
       sources get their type from the schema; columns of child scopes take the
       type of the producing expression in that map.
    3. Annotate the scope's own expression.

    A final ``_maybe_annotate`` call on the root handles expressions that scope
    traversal does not cover.

    Args:
        expression: The root expression to annotate.

    Returns:
        The annotated root expression, as returned by ``_maybe_annotate``.
    """
    for scope in traverse_scope(expression):
        # source name -> {output column name -> expression producing that column}
        projections = {}

        for source_name, child in scope.sources.items():
            if not isinstance(child, Scope):
                continue

            child_expression = child.expression

            if not isinstance(child_expression, exp.UDTF):
                projections[source_name] = {
                    projection.alias_or_name: projection
                    for projection in child_expression.selects
                }
                continue

            # UDTF sources have no SELECT list: their outputs are paired positionally
            # with the alias column names. Among laterals only EXPLODE(...) is handled;
            # any other lateral yields no outputs and is skipped.
            if isinstance(child_expression, exp.Lateral):
                call = child_expression.this
                produced = [call.this] if isinstance(call, exp.Explode) else []
            else:
                produced = child_expression.expressions[0].expressions

            if produced:
                projections[source_name] = dict(
                    zip(child_expression.alias_column_names, produced)
                )

        # Type column references first; the remaining annotation of the scope reads them
        for column in scope.columns:
            table_name = column.table
            if not table_name:
                continue

            source = scope.sources.get(table_name)
            if isinstance(source, exp.Table):
                self._set_type(column, self.schema.get_column_type(source, column))
            elif source and column.name in projections.get(table_name, {}):
                self._set_type(column, projections[table_name][column.name].type)

        self._maybe_annotate(scope.expression)

    # Covers expressions that are not reachable through scope traversal
    return self._maybe_annotate(expression)