sqlglot.optimizer.annotate_types
from __future__ import annotations

import datetime
import functools
import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    B = t.TypeVar("B", bound=exp.Binary)

    BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
    BinaryCoercions = t.Dict[
        t.Tuple[exp.DataType.Type, exp.DataType.Type],
        BinaryCoercionFunc,
    ]


# Interval units that operate on date components
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps a data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    """Build an annotator that assigns the fixed type ``data_type`` to an expression."""
    return lambda self, e: self._annotate_with_type(e, data_type)


def _is_iso_date(text: str) -> bool:
    """Return True if ``text`` parses with ``datetime.date.fromisoformat``."""
    try:
        datetime.date.fromisoformat(text)
        return True
    except ValueError:
        return False


def _is_iso_datetime(text: str) -> bool:
    """Return True if ``text`` parses with ``datetime.datetime.fromisoformat``."""
    try:
        datetime.datetime.fromisoformat(text)
        return True
    except ValueError:
        return False


def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
    """Type ``<text literal> +/- <interval>``, casting the literal in place when possible.

    An ISO date combined with a date-component unit becomes DATE. Any other ISO date or
    datetime becomes DATETIME. Otherwise the result is UNKNOWN.
    Note: ``l`` is replaced in the AST by a CAST when a type is inferred.
    """
    date_text = l.name
    unit = r.text("unit").lower()

    is_iso_date = _is_iso_date(date_text)

    if is_iso_date and unit in DATE_UNITS:
        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
        return exp.DataType.Type.DATE

    # An ISO date is also an ISO datetime, but not vice versa
    if is_iso_date or _is_iso_datetime(date_text):
        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
        return exp.DataType.Type.DATETIME

    return exp.DataType.Type.UNKNOWN


def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
    """Type ``<date> +/- <interval>``.

    A sub-day unit (anything outside DATE_UNITS) promotes the result to DATETIME.
    Otherwise the date operand's type is kept.
    """
    unit = r.text("unit").lower()
    if unit not in DATE_UNITS:
        return exp.DataType.Type.DATETIME
    return l.type.this if l.type else exp.DataType.Type.UNKNOWN


def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
    """Wrap a binary coercion function so that it receives its operands in reverse order."""

    @functools.wraps(func)
    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
        return func(r, l)

    return _swapped


def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
    """Return ``coercions`` plus a mirrored ``(b, a)`` entry for every ``(a, b)`` key."""
    return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}


class _TypeAnnotator(type):
    """Metaclass that fills ``COERCES_TO`` from built-in type-precedence lists."""

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
        text_precedence = (
            exp.DataType.Type.TEXT,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.CHAR,
        )
        numeric_precedence = (
            exp.DataType.Type.DOUBLE,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.TINYINT,
        )
        timelike_precedence = (
            exp.DataType.Type.TIMESTAMPLTZ,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.DATE,
        )

        # Each type can be coerced into every type that precedes it in its precedence list
        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
            coerces_to = set()
            for data_type in type_precedence:
                klass.COERCES_TO[data_type] = coerces_to.copy()
                coerces_to |= {data_type}

        return klass


class TypeAnnotator(metaclass=_TypeAnnotator):
    """Infers and attaches data types to the nodes of a sqlglot AST.

    Behavior is driven by class-level tables that can be overridden by subclasses or
    replaced through constructor arguments:

    * ``ANNOTATORS``: expression class -> function that annotates it.
    * ``COERCES_TO``: data type -> set of data types it can be coerced into (filled in by
      the ``_TypeAnnotator`` metaclass).
    * ``BINARY_COERCIONS``: ``(left type, right type)`` of a binary operation -> function
      computing the result type.
    """

    # Expression classes whose result type is fixed, grouped by that type
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Expression class -> annotation function. Later keys in this dict literal win, so the
    # explicit entries at the bottom override the generic/fixed-type entries above them.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
        exp.DateSub: lambda self, e: self._annotate_dateadd(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that are propagated as-is by _maybe_coerce instead of being coerced
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    # Coercion functions for binary operations.
    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
    BINARY_COERCIONS: BinaryCoercions = {
        **swap_all(
            {
                (text_type, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
                for text_type in exp.DataType.TEXT_TYPES
            }
        ),
        **swap_all(
            {
                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
            }
        ),
    }

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
        binary_coercions: t.Optional[BinaryCoercions] = None,
    ) -> None:
        """Create an annotator.

        Any mapping argument that is falsy (None or empty) falls back to the corresponding
        class-level default (ANNOTATORS, COERCES_TO, BINARY_COERCIONS).
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO
        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
    ) -> None:
        """Assign ``target_type`` to ``expression`` and mark it as visited."""
        expression.type = target_type  # type: ignore
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """Annotate ``expression`` and its sub-expressions in place, and return it.

        Each scope's column references are resolved first: from the schema for table
        sources, or from the projections of derived sources. Then the rest of the scope is
        annotated. Resolving from derived projections assumes ``traverse_scope`` yields a
        source's scope before the scopes that read from it (NOTE(review): confirm).
        """
        for scope in traverse_scope(expression):
            # Derived source name -> {output column name: expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """Annotate ``expression`` unless already visited; unknown classes become UNKNOWN."""
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotate every direct child expression of ``expression``."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """Return the type resulting from combining ``type1`` and ``type2``.

        Rules are checked in this order:

        1. NULL is propagated if either type is NULL.
        2. Otherwise UNKNOWN is propagated if either type is UNKNOWN.
        3. Otherwise a nested type is returned as-is (``type1`` is checked first).
        4. Otherwise ``type2`` is chosen when ``type1`` can be coerced into it, else ``type1``.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, set()) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """Annotate a binary expression from the types of its two operands."""
        self._annotate_args(expression)

        left, right = expression.left, expression.right
        left_type, right_type = left.type.this, right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif (left_type, right_type) in self.binary_coercions:
            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """Annotate a unary expression or an Alias.

        Only logical negation (NOT) produces BOOLEAN. Every other unary node takes the
        type of its operand: parentheses, arithmetic negation, bitwise NOT, and aliases.
        The former ``exp.Condition`` check was too broad: unary operators derive from
        ``exp.Condition``, so ``-x`` was typed BOOLEAN.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Not):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """String literals are VARCHAR, integer literals INT, other numbers DOUBLE."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(
        self, expression: E, target_type: exp.DataType | exp.DataType.Type
    ) -> E:
        """Assign the fixed ``target_type`` to ``expression``, then annotate its children."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """Type ``expression`` by coercing together the types of the named argument(s).

        Args:
            expression: The expression to annotate.
            *args: Names of the arguments whose (possibly list-valued) expressions are combined.
            promote: Widen integer results to BIGINT and float results to DOUBLE.
            array: Wrap the resulting type in an ARRAY type.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression

    def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
        """Type DATE_ADD / DATE_SUB.

        A text operand is coerced via ``_coerce_literal_and_interval``. A DATE operand with
        a sub-day unit becomes DATETIME. Otherwise the operand's type is kept.
        """
        self._annotate_args(expression)

        if expression.this.type.this in exp.DataType.TEXT_TYPES:
            datatype = _coerce_literal_and_interval(expression.this, expression.interval())
        elif (
            expression.this.type.is_type(exp.DataType.Type.DATE)
            and expression.text("unit").lower() not in DATE_UNITS
        ):
            datatype = exp.DataType.Type.DATETIME
        else:
            datatype = expression.this.type

        self._set_type(expression, datatype)
        return expression
DATE_UNITS =
{'year_month', 'quarter', 'year', 'week', 'day', 'month'}
def
annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate (modified in place).
        schema: Database schema, or a plain dict that is normalized with ``ensure_schema``.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps a data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """
    # Normalize the schema first, then delegate all the work to a fresh annotator instance
    type_annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return type_annotator.annotate(expression)
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression: Expression to annotate.
- schema: Database schema.
- annotators: Maps expression type to corresponding annotation function.
- coerces_to: Maps a data type to the set of data types it can be coerced into.
Returns:
The expression annotated with types.
def
swap_args( func: Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]) -> Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]:
def
swap_all( coercions: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]) -> Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]:
class
TypeAnnotator:
class TypeAnnotator(metaclass=_TypeAnnotator):
    """Infers and attaches data types to the nodes of a sqlglot AST.

    Behavior is driven by class-level tables that can be overridden by subclasses or
    replaced through constructor arguments:

    * ``ANNOTATORS``: expression class -> function that annotates it.
    * ``COERCES_TO``: data type -> set of data types it can be coerced into (filled in by
      the ``_TypeAnnotator`` metaclass).
    * ``BINARY_COERCIONS``: ``(left type, right type)`` of a binary operation -> function
      computing the result type.
    """

    # Expression classes whose result type is fixed, grouped by that type
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.Timestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.SafeDPipe,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Expression class -> annotation function. Later keys in this dict literal win, so the
    # explicit entries at the bottom override the generic/fixed-type entries above them.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
        exp.DateSub: lambda self, e: self._annotate_dateadd(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that are propagated as-is by _maybe_coerce instead of being coerced
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    # Coercion functions for binary operations.
    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
    BINARY_COERCIONS: BinaryCoercions = {
        **swap_all(
            {
                (text_type, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
                for text_type in exp.DataType.TEXT_TYPES
            }
        ),
        **swap_all(
            {
                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
            }
        ),
    }

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
        binary_coercions: t.Optional[BinaryCoercions] = None,
    ) -> None:
        """Create an annotator.

        Any mapping argument that is falsy (None or empty) falls back to the corresponding
        class-level default (ANNOTATORS, COERCES_TO, BINARY_COERCIONS).
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO
        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS

        # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
        self._visited: t.Set[int] = set()

    def _set_type(
        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
    ) -> None:
        """Assign ``target_type`` to ``expression`` and mark it as visited."""
        expression.type = target_type  # type: ignore
        self._visited.add(id(expression))

    def annotate(self, expression: E) -> E:
        """Annotate ``expression`` and its sub-expressions in place, and return it.

        Each scope's column references are resolved first: from the schema for table
        sources, or from the projections of derived sources. Then the rest of the scope is
        annotated. Resolving from derived projections assumes ``traverse_scope`` yields a
        source's scope before the scopes that read from it (NOTE(review): confirm).
        """
        for scope in traverse_scope(expression):
            # Derived source name -> {output column name: expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    self._set_type(col, self.schema.get_column_type(source, col))
                elif source and col.table in selects and col.name in selects[col.table]:
                    self._set_type(col, selects[col.table][col.name].type)

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """Annotate ``expression`` unless already visited; unknown classes become UNKNOWN."""
        if id(expression) in self._visited:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotate every direct child expression of ``expression``."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """Return the type resulting from combining ``type1`` and ``type2``.

        Rules are checked in this order:

        1. NULL is propagated if either type is NULL.
        2. Otherwise UNKNOWN is propagated if either type is UNKNOWN.
        3. Otherwise a nested type is returned as-is (``type1`` is checked first).
        4. Otherwise ``type2`` is chosen when ``type1`` can be coerced into it, else ``type1``.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, set()) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """Annotate a binary expression from the types of its two operands."""
        self._annotate_args(expression)

        left, right = expression.left, expression.right
        left_type, right_type = left.type.this, right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                self._set_type(expression, exp.DataType.Type.NULL)
            elif exp.DataType.Type.NULL in (left_type, right_type):
                self._set_type(
                    expression,
                    exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
                )
            else:
                self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif isinstance(expression, exp.Predicate):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        elif (left_type, right_type) in self.binary_coercions:
            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
        else:
            self._set_type(expression, self._maybe_coerce(left_type, right_type))

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """Annotate a unary expression or an Alias.

        Only logical negation (NOT) produces BOOLEAN. Every other unary node takes the
        type of its operand: parentheses, arithmetic negation, bitwise NOT, and aliases.
        The former ``exp.Condition`` check was too broad: unary operators derive from
        ``exp.Condition``, so ``-x`` was typed BOOLEAN.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Not):
            self._set_type(expression, exp.DataType.Type.BOOLEAN)
        else:
            self._set_type(expression, expression.this.type)

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """String literals are VARCHAR, integer literals INT, other numbers DOUBLE."""
        if expression.is_string:
            self._set_type(expression, exp.DataType.Type.VARCHAR)
        elif expression.is_int:
            self._set_type(expression, exp.DataType.Type.INT)
        else:
            self._set_type(expression, exp.DataType.Type.DOUBLE)

        return expression

    @t.no_type_check
    def _annotate_with_type(
        self, expression: E, target_type: exp.DataType | exp.DataType.Type
    ) -> E:
        """Assign the fixed ``target_type`` to ``expression``, then annotate its children."""
        self._set_type(expression, target_type)
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """Type ``expression`` by coercing together the types of the named argument(s).

        Args:
            expression: The expression to annotate.
            *args: Names of the arguments whose (possibly list-valued) expressions are combined.
            promote: Widen integer results to BIGINT and float results to DOUBLE.
            array: Wrap the resulting type in an ARRAY type.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                self._set_type(expression, exp.DataType.Type.BIGINT)
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                self._set_type(expression, exp.DataType.Type.DOUBLE)

        if array:
            self._set_type(
                expression,
                exp.DataType(
                    this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
                ),
            )

        return expression

    def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
        """Type DATE_ADD / DATE_SUB.

        A text operand is coerced via ``_coerce_literal_and_interval``. A DATE operand with
        a sub-day unit becomes DATETIME. Otherwise the operand's type is kept.
        """
        self._annotate_args(expression)

        if expression.this.type.this in exp.DataType.TEXT_TYPES:
            datatype = _coerce_literal_and_interval(expression.this, expression.interval())
        elif (
            expression.this.type.is_type(exp.DataType.Type.DATE)
            and expression.text("unit").lower() not in DATE_UNITS
        ):
            datatype = exp.DataType.Type.DATETIME
        else:
            datatype = expression.this.type

        self._set_type(expression, datatype)
        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None, binary_coercions: Optional[Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]]] = None)
319 def __init__( 320 self, 321 schema: Schema, 322 annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, 323 coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, 324 binary_coercions: t.Optional[BinaryCoercions] = None, 325 ) -> None: 326 self.schema = schema 327 self.annotators = annotators or self.ANNOTATORS 328 self.coerces_to = coerces_to or self.COERCES_TO 329 self.binary_coercions = binary_coercions or self.BINARY_COERCIONS 330 331 # Caches the ids of annotated sub-Expressions, to ensure we only visit them once 332 self._visited: t.Set[int] = set()
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] =
{<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ApproxDistinct'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.Boolean'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.Date'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Log10'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.Levenshtein'>, <class 
'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.TimeDiff'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.Timestamp'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.UnixToTime'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Year'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.SafeDPipe'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.ConcatWs'>}}
ANNOTATORS: Dict =
{<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 
'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Timestamp'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] =
{<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.CHAR: 'CHAR'>: {<Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.NCHAR: 'NCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.BIGINT: 'BIGINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.INT: 'INT'>: {<Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.INT: 'INT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.TINYINT: 'TINYINT'>: {<Type.INT: 'INT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>, <Type.SMALLINT: 'SMALLINT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATE: 'DATE'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}}
BINARY_COERCIONS: Dict[Tuple[sqlglot.expressions.DataType.Type, sqlglot.expressions.DataType.Type], Callable[[sqlglot.expressions.Expression, sqlglot.expressions.Expression], sqlglot.expressions.DataType.Type]] =
{(<Type.NCHAR: 'NCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.VARCHAR: 'VARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.NVARCHAR: 'NVARCHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.TEXT: 'TEXT'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.CHAR: 'CHAR'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NCHAR: 'NCHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.VARCHAR: 'VARCHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.NVARCHAR: 'NVARCHAR'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.TEXT: 'TEXT'>): <function _coerce_literal_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.CHAR: 'CHAR'>): <function _coerce_literal_and_interval>, (<Type.DATE: 'DATE'>, <Type.INTERVAL: 'INTERVAL'>): <function _coerce_date_and_interval>, (<Type.INTERVAL: 'INTERVAL'>, <Type.DATE: 'DATE'>): <function _coerce_date_and_interval>}
def
annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """Infer and attach types to ``expression`` and its sub-expressions.

    For every scope yielded by ``traverse_scope(expression)``:

    1. Build, for each derived source (a ``Scope``), a map from output column
       name to the expression producing it. For UDTF sources the names come from
       ``alias_column_names`` and are paired with the UDTF's values. For other
       sources, each projection is keyed by ``alias_or_name``.
    2. Type every table-qualified column reference. If the source is a physical
       ``exp.Table``, the type comes from the schema. If it is a derived source,
       the type is copied from the matching projection. Unqualified columns are
       skipped here.
    3. Annotate the scope's own expression.

    A final pass over ``expression`` covers nodes that are not part of any
    traversable scope.

    NOTE(review): step 2 reads ``.type`` off derived-source projections, so this
    presumably relies on ``traverse_scope`` yielding source scopes before the
    scopes that select from them. Confirm against ``traverse_scope``.

    Args:
        expression: The expression tree to annotate.

    Returns:
        The annotated expression.
    """
    for scope in traverse_scope(expression):
        selects: t.Dict[str, t.Dict[str, exp.Expression]] = {}
        for name, source in scope.sources.items():
            if not isinstance(source, Scope):
                continue

            source_expression = source.expression
            if isinstance(source_expression, exp.UDTF):
                values: t.List[exp.Expression] = []

                if isinstance(source_expression, exp.Lateral):
                    if isinstance(source_expression.this, exp.Explode):
                        values = [source_expression.this.this]
                elif source_expression.expressions:
                    # Other UDTFs: column values are the children of the first argument.
                    # Guarded so a UDTF without arguments is skipped instead of raising.
                    values = source_expression.expressions[0].expressions

                if not values:
                    continue

                selects[name] = dict(zip(source_expression.alias_column_names, values))
            else:
                selects[name] = {
                    select.alias_or_name: select for select in source_expression.selects
                }

        # First annotate the current scope's column references
        for col in scope.columns:
            if not col.table:
                continue

            source = scope.sources.get(col.table)
            if isinstance(source, exp.Table):
                self._set_type(col, self.schema.get_column_type(source, col))
            elif source and col.table in selects and col.name in selects[col.table]:
                self._set_type(col, selects[col.table][col.name].type)

        # Then (possibly) annotate the remaining expressions in the scope
        self._maybe_annotate(scope.expression)

    return self._maybe_annotate(expression)  # This takes care of non-traversable expressions