sqlglot.optimizer.annotate_types
1from __future__ import annotations 2 3import datetime 4import functools 5import typing as t 6 7from sqlglot import exp 8from sqlglot._typing import E 9from sqlglot.helper import ensure_list, seq_get, subclasses 10from sqlglot.optimizer.scope import Scope, traverse_scope 11from sqlglot.schema import Schema, ensure_schema 12 13if t.TYPE_CHECKING: 14 B = t.TypeVar("B", bound=exp.Binary) 15 16 BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] 17 BinaryCoercions = t.Dict[ 18 t.Tuple[exp.DataType.Type, exp.DataType.Type], 19 BinaryCoercionFunc, 20 ] 21 22 23# Interval units that operate on date components 24DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} 25 26 27def annotate_types( 28 expression: E, 29 schema: t.Optional[t.Dict | Schema] = None, 30 annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, 31 coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, 32) -> E: 33 """ 34 Infers the types of an expression, annotating its AST accordingly. 35 36 Example: 37 >>> import sqlglot 38 >>> schema = {"y": {"cola": "SMALLINT"}} 39 >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" 40 >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) 41 >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" 42 <Type.DOUBLE: 'DOUBLE'> 43 44 Args: 45 expression: Expression to annotate. 46 schema: Database schema. 47 annotators: Maps expression type to corresponding annotation function. 48 coerces_to: Maps expression type to set of types that it can be coerced into. 49 50 Returns: 51 The expression annotated with types. 
52 """ 53 54 schema = ensure_schema(schema) 55 56 return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) 57 58 59def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: 60 return lambda self, e: self._annotate_with_type(e, data_type) 61 62 63def _is_iso_date(text: str) -> bool: 64 try: 65 datetime.date.fromisoformat(text) 66 return True 67 except ValueError: 68 return False 69 70 71def _is_iso_datetime(text: str) -> bool: 72 try: 73 datetime.datetime.fromisoformat(text) 74 return True 75 except ValueError: 76 return False 77 78 79def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: 80 date_text = l.name 81 unit = r.text("unit").lower() 82 83 is_iso_date = _is_iso_date(date_text) 84 85 if is_iso_date and unit in DATE_UNITS: 86 l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE)) 87 return exp.DataType.Type.DATE 88 89 # An ISO date is also an ISO datetime, but not vice versa 90 if is_iso_date or _is_iso_datetime(date_text): 91 l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME)) 92 return exp.DataType.Type.DATETIME 93 94 return exp.DataType.Type.UNKNOWN 95 96 97def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: 98 unit = r.text("unit").lower() 99 if unit not in DATE_UNITS: 100 return exp.DataType.Type.DATETIME 101 return l.type.this if l.type else exp.DataType.Type.UNKNOWN 102 103 104def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc: 105 @functools.wraps(func) 106 def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: 107 return func(r, l) 108 109 return _swapped 110 111 112def swap_all(coercions: BinaryCoercions) -> BinaryCoercions: 113 return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}} 114 115 116class _TypeAnnotator(type): 117 def __new__(cls, clsname, bases, attrs): 118 klass = super().__new__(cls, clsname, bases, attrs) 119 120 # 
Highest-to-lowest type precedence, as specified in Spark's docs (ANSI): 121 # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html 122 text_precedence = ( 123 exp.DataType.Type.TEXT, 124 exp.DataType.Type.NVARCHAR, 125 exp.DataType.Type.VARCHAR, 126 exp.DataType.Type.NCHAR, 127 exp.DataType.Type.CHAR, 128 ) 129 numeric_precedence = ( 130 exp.DataType.Type.DOUBLE, 131 exp.DataType.Type.FLOAT, 132 exp.DataType.Type.DECIMAL, 133 exp.DataType.Type.BIGINT, 134 exp.DataType.Type.INT, 135 exp.DataType.Type.SMALLINT, 136 exp.DataType.Type.TINYINT, 137 ) 138 timelike_precedence = ( 139 exp.DataType.Type.TIMESTAMPLTZ, 140 exp.DataType.Type.TIMESTAMPTZ, 141 exp.DataType.Type.TIMESTAMP, 142 exp.DataType.Type.DATETIME, 143 exp.DataType.Type.DATE, 144 ) 145 146 for type_precedence in (text_precedence, numeric_precedence, timelike_precedence): 147 coerces_to = set() 148 for data_type in type_precedence: 149 klass.COERCES_TO[data_type] = coerces_to.copy() 150 coerces_to |= {data_type} 151 152 return klass 153 154 155class TypeAnnotator(metaclass=_TypeAnnotator): 156 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 157 exp.DataType.Type.BIGINT: { 158 exp.ApproxDistinct, 159 exp.ArraySize, 160 exp.Count, 161 exp.Length, 162 }, 163 exp.DataType.Type.BOOLEAN: { 164 exp.Between, 165 exp.Boolean, 166 exp.In, 167 exp.RegexpLike, 168 }, 169 exp.DataType.Type.DATE: { 170 exp.CurrentDate, 171 exp.Date, 172 exp.DateFromParts, 173 exp.DateStrToDate, 174 exp.DateTrunc, 175 exp.DiToDate, 176 exp.StrToDate, 177 exp.TimeStrToDate, 178 exp.TsOrDsToDate, 179 }, 180 exp.DataType.Type.DATETIME: { 181 exp.CurrentDatetime, 182 exp.DatetimeAdd, 183 exp.DatetimeSub, 184 }, 185 exp.DataType.Type.DOUBLE: { 186 exp.ApproxQuantile, 187 exp.Avg, 188 exp.Exp, 189 exp.Ln, 190 exp.Log, 191 exp.Log2, 192 exp.Log10, 193 exp.Pow, 194 exp.Quantile, 195 exp.Round, 196 exp.SafeDivide, 197 exp.Sqrt, 198 exp.Stddev, 199 exp.StddevPop, 200 exp.StddevSamp, 201 exp.Variance, 202 
exp.VariancePop, 203 }, 204 exp.DataType.Type.INT: { 205 exp.Ceil, 206 exp.DateDiff, 207 exp.DatetimeDiff, 208 exp.Extract, 209 exp.TimestampDiff, 210 exp.TimeDiff, 211 exp.DateToDi, 212 exp.Floor, 213 exp.Levenshtein, 214 exp.StrPosition, 215 exp.TsOrDiToDi, 216 }, 217 exp.DataType.Type.TIMESTAMP: { 218 exp.CurrentTime, 219 exp.CurrentTimestamp, 220 exp.StrToTime, 221 exp.TimeAdd, 222 exp.TimeStrToTime, 223 exp.TimeSub, 224 exp.Timestamp, 225 exp.TimestampAdd, 226 exp.TimestampSub, 227 exp.UnixToTime, 228 }, 229 exp.DataType.Type.TINYINT: { 230 exp.Day, 231 exp.Month, 232 exp.Week, 233 exp.Year, 234 }, 235 exp.DataType.Type.VARCHAR: { 236 exp.ArrayConcat, 237 exp.Concat, 238 exp.ConcatWs, 239 exp.DateToDateStr, 240 exp.GroupConcat, 241 exp.Initcap, 242 exp.Lower, 243 exp.SafeConcat, 244 exp.SafeDPipe, 245 exp.Substring, 246 exp.TimeToStr, 247 exp.TimeToTimeStr, 248 exp.Trim, 249 exp.TsOrDsToDateStr, 250 exp.UnixToStr, 251 exp.UnixToTimeStr, 252 exp.Upper, 253 }, 254 } 255 256 ANNOTATORS: t.Dict = { 257 **{ 258 expr_type: lambda self, e: self._annotate_unary(e) 259 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 260 }, 261 **{ 262 expr_type: lambda self, e: self._annotate_binary(e) 263 for expr_type in subclasses(exp.__name__, exp.Binary) 264 }, 265 **{ 266 expr_type: _annotate_with_type_lambda(data_type) 267 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 268 for expr_type in expressions 269 }, 270 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 271 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 272 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 273 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 274 exp.Bracket: lambda self, e: self._annotate_bracket(e), 275 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 276 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 
277 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 278 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 279 exp.DateAdd: lambda self, e: self._annotate_dateadd(e), 280 exp.DateSub: lambda self, e: self._annotate_dateadd(e), 281 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 282 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 283 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 284 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 285 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 286 exp.Literal: lambda self, e: self._annotate_literal(e), 287 exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), 288 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 289 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 290 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 291 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 292 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 293 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 294 exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), 295 } 296 297 NESTED_TYPES = { 298 exp.DataType.Type.ARRAY, 299 } 300 301 # Specifies what types a given type can be coerced into (autofilled) 302 COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} 303 304 # Coercion functions for binary operations. 305 # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type. 
306 BINARY_COERCIONS: BinaryCoercions = { 307 **swap_all( 308 { 309 (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval 310 for t in exp.DataType.TEXT_TYPES 311 } 312 ), 313 **swap_all( 314 { 315 (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval, 316 } 317 ), 318 } 319 320 def __init__( 321 self, 322 schema: Schema, 323 annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, 324 coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, 325 binary_coercions: t.Optional[BinaryCoercions] = None, 326 ) -> None: 327 self.schema = schema 328 self.annotators = annotators or self.ANNOTATORS 329 self.coerces_to = coerces_to or self.COERCES_TO 330 self.binary_coercions = binary_coercions or self.BINARY_COERCIONS 331 332 # Caches the ids of annotated sub-Expressions, to ensure we only visit them once 333 self._visited: t.Set[int] = set() 334 335 def _set_type( 336 self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type 337 ) -> None: 338 expression.type = target_type # type: ignore 339 self._visited.add(id(expression)) 340 341 def annotate(self, expression: E) -> E: 342 for scope in traverse_scope(expression): 343 selects = {} 344 for name, source in scope.sources.items(): 345 if not isinstance(source, Scope): 346 continue 347 if isinstance(source.expression, exp.UDTF): 348 values = [] 349 350 if isinstance(source.expression, exp.Lateral): 351 if isinstance(source.expression.this, exp.Explode): 352 values = [source.expression.this.this] 353 else: 354 values = source.expression.expressions[0].expressions 355 356 if not values: 357 continue 358 359 selects[name] = { 360 alias: column 361 for alias, column in zip( 362 source.expression.alias_column_names, 363 values, 364 ) 365 } 366 else: 367 selects[name] = { 368 select.alias_or_name: select for select in source.expression.selects 369 } 370 371 # First annotate the current scope's column references 372 for col 
in scope.columns: 373 if not col.table: 374 continue 375 376 source = scope.sources.get(col.table) 377 if isinstance(source, exp.Table): 378 self._set_type(col, self.schema.get_column_type(source, col)) 379 elif source and col.table in selects and col.name in selects[col.table]: 380 self._set_type(col, selects[col.table][col.name].type) 381 382 # Then (possibly) annotate the remaining expressions in the scope 383 self._maybe_annotate(scope.expression) 384 385 return self._maybe_annotate(expression) # This takes care of non-traversable expressions 386 387 def _maybe_annotate(self, expression: E) -> E: 388 if id(expression) in self._visited: 389 return expression # We've already inferred the expression's type 390 391 annotator = self.annotators.get(expression.__class__) 392 393 return ( 394 annotator(self, expression) 395 if annotator 396 else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) 397 ) 398 399 def _annotate_args(self, expression: E) -> E: 400 for _, value in expression.iter_expressions(): 401 self._maybe_annotate(value) 402 403 return expression 404 405 def _maybe_coerce( 406 self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type 407 ) -> exp.DataType | exp.DataType.Type: 408 type1_value = type1.this if isinstance(type1, exp.DataType) else type1 409 type2_value = type2.this if isinstance(type2, exp.DataType) else type2 410 411 # We propagate the NULL / UNKNOWN types upwards if found 412 if exp.DataType.Type.NULL in (type1_value, type2_value): 413 return exp.DataType.Type.NULL 414 if exp.DataType.Type.UNKNOWN in (type1_value, type2_value): 415 return exp.DataType.Type.UNKNOWN 416 417 if type1_value in self.NESTED_TYPES: 418 return type1 419 if type2_value in self.NESTED_TYPES: 420 return type2 421 422 return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore 423 424 # Note: the following "no_type_check" decorators were added because mypy was yelling due 425 # to 
assigning Type values to expression.type (since its getter returns Optional[DataType]). 426 # This is a known mypy issue: https://github.com/python/mypy/issues/3004 427 428 @t.no_type_check 429 def _annotate_binary(self, expression: B) -> B: 430 self._annotate_args(expression) 431 432 left, right = expression.left, expression.right 433 left_type, right_type = left.type.this, right.type.this 434 435 if isinstance(expression, exp.Connector): 436 if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: 437 self._set_type(expression, exp.DataType.Type.NULL) 438 elif exp.DataType.Type.NULL in (left_type, right_type): 439 self._set_type( 440 expression, 441 exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")), 442 ) 443 else: 444 self._set_type(expression, exp.DataType.Type.BOOLEAN) 445 elif isinstance(expression, exp.Predicate): 446 self._set_type(expression, exp.DataType.Type.BOOLEAN) 447 elif (left_type, right_type) in self.binary_coercions: 448 self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right)) 449 else: 450 self._set_type(expression, self._maybe_coerce(left_type, right_type)) 451 452 return expression 453 454 @t.no_type_check 455 def _annotate_unary(self, expression: E) -> E: 456 self._annotate_args(expression) 457 458 if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): 459 self._set_type(expression, exp.DataType.Type.BOOLEAN) 460 else: 461 self._set_type(expression, expression.this.type) 462 463 return expression 464 465 @t.no_type_check 466 def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: 467 if expression.is_string: 468 self._set_type(expression, exp.DataType.Type.VARCHAR) 469 elif expression.is_int: 470 self._set_type(expression, exp.DataType.Type.INT) 471 else: 472 self._set_type(expression, exp.DataType.Type.DOUBLE) 473 474 return expression 475 476 @t.no_type_check 477 def _annotate_with_type(self, expression: E, target_type: 
exp.DataType.Type) -> E: 478 self._set_type(expression, target_type) 479 return self._annotate_args(expression) 480 481 @t.no_type_check 482 def _annotate_by_args( 483 self, expression: E, *args: str, promote: bool = False, array: bool = False 484 ) -> E: 485 self._annotate_args(expression) 486 487 expressions: t.List[exp.Expression] = [] 488 for arg in args: 489 arg_expr = expression.args.get(arg) 490 expressions.extend(expr for expr in ensure_list(arg_expr) if expr) 491 492 last_datatype = None 493 for expr in expressions: 494 last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) 495 496 self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN) 497 498 if promote: 499 if expression.type.this in exp.DataType.INTEGER_TYPES: 500 self._set_type(expression, exp.DataType.Type.BIGINT) 501 elif expression.type.this in exp.DataType.FLOAT_TYPES: 502 self._set_type(expression, exp.DataType.Type.DOUBLE) 503 504 if array: 505 self._set_type( 506 expression, 507 exp.DataType( 508 this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True 509 ), 510 ) 511 512 return expression 513 514 def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp: 515 self._annotate_args(expression) 516 517 if expression.this.type.this in exp.DataType.TEXT_TYPES: 518 datatype = _coerce_literal_and_interval(expression.this, expression.interval()) 519 elif ( 520 expression.this.type.is_type(exp.DataType.Type.DATE) 521 and expression.text("unit").lower() not in DATE_UNITS 522 ): 523 datatype = exp.DataType.Type.DATETIME 524 else: 525 datatype = expression.this.type 526 527 self._set_type(expression, datatype) 528 return expression 529 530 def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket: 531 self._annotate_args(expression) 532 533 bracket_arg = expression.expressions[0] 534 this = expression.this 535 536 if isinstance(bracket_arg, exp.Slice): 537 self._set_type(expression, this.type) 538 elif 
this.type.is_type(exp.DataType.Type.ARRAY): 539 contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN 540 self._set_type(expression, contained_type) 541 elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys: 542 index = this.keys.index(bracket_arg) 543 value = seq_get(this.values, index) 544 value_type = value.type if value else exp.DataType.Type.UNKNOWN 545 self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN) 546 else: 547 self._set_type(expression, exp.DataType.Type.UNKNOWN) 548 549 return expression
DATE_UNITS =
{'year_month', 'quarter', 'year', 'week', 'month', 'day'}
def
annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema; it is normalized through `ensure_schema` before use.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps a data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression: Expression to annotate.
- schema: Database schema.
- annotators: Maps expression type to corresponding annotation function.
- coerces_to: Maps a data type to the set of data types that it can be coerced into.
Returns:
The expression annotated with types.
def
swap_args( func: Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]) -> Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]:
def
swap_all( coercions: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]) -> Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]:
class
TypeAnnotator:
class TypeAnnotator(metaclass=_TypeAnnotator):
    """Walks an expression tree and assigns a type to every node it visits."""

    # Expressions whose type is fixed, grouped by that type
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to its annotation function. Later entries override
    # earlier ones, so the explicit entries win over the generated ones.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
        exp.DateSub: lambda self, e: self._annotate_dateadd(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    # Coercion functions for binary operations.
    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
    BINARY_COERCIONS: BinaryCoercions = {
        **swap_all(
            {
                (text_type, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
                for text_type in exp.DataType.TEXT_TYPES
            }
        ),
        **swap_all(
            {
                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
            }
        ),
    }

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
        binary_coercions: t.Optional[BinaryCoercions] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of table columns.
            annotators: Overrides `ANNOTATORS`; a falsy value selects the class default.
            coerces_to: Overrides `COERCES_TO`; a falsy value selects the class default.
            binary_coercions: Overrides `BINARY_COERCIONS`; a falsy value selects the
                class default.
        """
        self.schema = schema
        self.annotators = annotators if annotators else self.ANNOTATORS
        self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
        self.binary_coercions = binary_coercions if binary_coercions else self.BINARY_COERCIONS

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
    ) -> None:
        """Assign `target_type` to `expression` and record the node as visited."""
        expression.type = target_type  # type: ignore
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """
        Annotate `expression` scope by scope and return it.

        For every scope, qualified column references are typed first (from the schema
        for tables, from the source's projections otherwise), then the rest of the
        scope's expression is annotated.
        """
        for scope in traverse_scope(expression):
            # Source name -> {projection name -> projected expression}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue

                source_expr = source.expression
                if not isinstance(source_expr, exp.UDTF):
                    selects[name] = {
                        select.alias_or_name: select for select in source_expr.selects
                    }
                    continue

                if not isinstance(source_expr, exp.Lateral):
                    values = source_expr.expressions[0].expressions
                elif isinstance(source_expr.this, exp.Explode):
                    values = [source_expr.this.this]
                else:
                    values = []

                if values:
                    selects[name] = dict(zip(source_expr.alias_column_names, values))

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                col_source = scope.sources.get(col.table)
                if isinstance(col_source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(col_source, col))
                elif col_source and col.name in selects.get(col.table, {}):
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        # This takes care of non-traversable expressions
        return self._maybe_annotate(expression)

    def _maybe_annotate(self, expression: E) -> E:
        """Annotate `expression` unless it has been visited; unmapped classes get UNKNOWN."""
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)
        if annotator:
            return annotator(self, expression)

        return self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)

    def _annotate_args(self, expression: E) -> E:
        """Annotate every child yielded by `expression.iter_expressions()`."""
        for _, child in expression.iter_expressions():
            self._maybe_annotate(child)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Pick the resulting type for a pair of operand types.

        NULL wins over everything, then UNKNOWN, then the first nested operand type
        (returned as passed in). Otherwise `type2` is returned if `type1` can be
        coerced into it, else `type1`.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
        operand_types = (type1_value, type2_value)

        # We propagate the NULL / UNKNOWN types upwards if found
        for propagated in (exp.DataType.Type.NULL, exp.DataType.Type.UNKNOWN):
            if propagated in operand_types:
                return propagated

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        if type2_value in self.coerces_to.get(type1_value, {}):
            return type2_value  # type: ignore
        return type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """Annotate a binary expression from the types of its two operands."""
        self._annotate_args(expression)

        left, right = expression.left, expression.right
        left_type, right_type = left.type.this, right.type.this
        type_pair = (left_type, right_type)
        null_type = exp.DataType.Type.NULL

        if isinstance(expression, exp.Connector):
            if left_type == null_type and right_type == null_type:
                result_type = null_type
            elif null_type in type_pair:
                result_type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                result_type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, exp.Predicate):
            result_type = exp.DataType.Type.BOOLEAN
        elif type_pair in self.binary_coercions:
            result_type = self.binary_coercions[type_pair](left, right)
        else:
            result_type = self._maybe_coerce(left_type, right_type)

        self._set_type(expression, result_type)
        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """Annotate a unary expression: BOOLEAN for non-Paren conditions, else the operand's type."""
        self._annotate_args(expression)

        is_boolean = isinstance(expression, exp.Condition) and not isinstance(
            expression, exp.Paren
        )
        self._set_type(
            expression, exp.DataType.Type.BOOLEAN if is_boolean else expression.this.type
        )

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Annotate a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            literal_type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            literal_type = exp.DataType.Type.INT
        else:
            literal_type = exp.DataType.Type.DOUBLE

        self._set_type(expression, literal_type)
        return expression

    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        """Assign `target_type` to `expression`, then annotate its children."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Annotate `expression` with the type obtained by coercing the named args together.

        Args:
            expression: Expression to annotate.
            *args: Names of the `expression.args` entries whose types are combined.
            promote: Widen integer results to BIGINT and float results to DOUBLE.
            array: Wrap the resulting type in an ARRAY type.
        """
        self._annotate_args(expression)

        operands: t.List[exp.Expression] = [
            expr
            for arg in args
            for expr in ensure_list(expression.args.get(arg))
            if expr
        ]

        last_datatype = None
        for operand in operands:
            last_datatype = self._maybe_coerce(last_datatype or operand.type, operand.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            array_type = exp.DataType(
                this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
            )
            self._set_type(expression, array_type)

        return expression

    def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
        """Annotate a date add / sub operation from its operand's type and its unit."""
        self._annotate_args(expression)

        operand_type = expression.this.type
        if operand_type.this in exp.DataType.TEXT_TYPES:
            datatype = _coerce_literal_and_interval(expression.this, expression.interval())
        elif (
            operand_type.is_type(exp.DataType.Type.DATE)
            and expression.text("unit").lower() not in DATE_UNITS
        ):
            # A DATE combined with a non-date unit yields a DATETIME
            datatype = exp.DataType.Type.DATETIME
        else:
            datatype = operand_type

        self._set_type(expression, datatype)
        return expression

    def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
        """Annotate a bracket access (slice, array element or map lookup)."""
        self._annotate_args(expression)

        bracket_arg = expression.expressions[0]
        this = expression.this
        unknown = exp.DataType.Type.UNKNOWN

        if isinstance(bracket_arg, exp.Slice):
            # Slicing keeps the type of the sliced expression
            result_type = this.type
        elif this.type.is_type(exp.DataType.Type.ARRAY):
            result_type = seq_get(this.type.expressions, 0) or unknown
        elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
            value = seq_get(this.values, this.keys.index(bracket_arg))
            result_type = (value.type if value else unknown) or unknown
        else:
            result_type = unknown

        self._set_type(expression, result_type)
        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None, binary_coercions: Optional[Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    binary_coercions: t.Optional[BinaryCoercions] = None,
) -> None:
    """
    Stores the schema and the lookup tables used during annotation.

    Any table that is omitted or falsy (e.g. an empty dict) is replaced by the
    corresponding class-level default.

    Args:
        schema: Schema stored on the instance as `self.schema`.
        annotators: Maps expression type to corresponding annotation function;
            falls back to `self.ANNOTATORS`.
        coerces_to: Maps a type to the set of types that it can be coerced into;
            falls back to `self.COERCES_TO`.
        binary_coercions: Maps a pair of operand types to a coercion function;
            falls back to `self.BINARY_COERCIONS`.
    """
    # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
    self._visited: t.Set[int] = set()

    self.schema = schema
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
    self.binary_coercions = binary_coercions if binary_coercions else self.BINARY_COERCIONS
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] =
{<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.Count'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.In'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.TsOrDsToDate'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeSub'>, <class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.Ln'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.StrPosition'>, <class 
'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.Extract'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.Timestamp'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Month'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.SafeDPipe'>}}
ANNOTATORS: Dict =
{<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 
'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: 
<function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: 
<function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Bracket'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: 
<function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] =
{<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.NCHAR: 'NCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NCHAR: 'NCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.FLOAT: 'FLOAT'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.INT: 'INT'>: {<Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.INT: 'INT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.INT: 'INT'>, <Type.SMALLINT: 'SMALLINT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATE: 'DATE'>: {<Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}}
BINARY_COERCIONS: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]] =
{(<Type.TEXT: 'TEXT'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.CHAR: 'CHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.VARCHAR: 'VARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.NVARCHAR: 'NVARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.NCHAR: 'NCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.TEXT: 'TEXT'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.CHAR: 'CHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.VARCHAR: 'VARCHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NVARCHAR: 'NVARCHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NCHAR: 'NCHAR'>): <function _coerce_literal_and_interval>, (<Type.DATE: 'DATE'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_date_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.DATE: 'DATE'>): <function _coerce_date_and_interval>}
def
annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """
    Annotates an expression with types, one scope at a time.

    For every scope yielded by `traverse_scope`, the table-qualified columns are typed
    first (from the schema or from the projections of a nested scope) and the scope's
    expression is then passed to `self._maybe_annotate`.

    Args:
        expression: Expression to annotate.

    Returns:
        The result of `self._maybe_annotate` for the given expression.
    """
    for scope in traverse_scope(expression):
        # Source name -> {output column name -> expression producing that column}
        projections = {}

        for source_name, source in scope.sources.items():
            if not isinstance(source, Scope):
                continue

            source_expr = source.expression

            if not isinstance(source_expr, exp.UDTF):
                projections[source_name] = {
                    select.alias_or_name: select for select in source_expr.selects
                }
                continue

            if not isinstance(source_expr, exp.Lateral):
                values = source_expr.expressions[0].expressions
            elif isinstance(source_expr.this, exp.Explode):
                values = [source_expr.this.this]
            else:
                # A lateral that does not wrap an Explode contributes no columns
                values = []

            if values:
                projections[source_name] = dict(zip(source_expr.alias_column_names, values))

        # First annotate the current scope's column references
        for column in scope.columns:
            table_name = column.table
            if not table_name:
                continue

            origin = scope.sources.get(table_name)
            if isinstance(origin, exp.Table):
                self._set_type(column, self.schema.get_column_type(origin, column))
            elif origin and table_name in projections and column.name in projections[table_name]:
                self._set_type(column, projections[table_name][column.name].type)

        # Then (possibly) annotate the remaining expressions in the scope
        self._maybe_annotate(scope.expression)

    return self._maybe_annotate(expression)  # This takes care of non-traversable expressions