sqlglot.optimizer.annotate_types
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Recursively infer & annotate types in an expression syntax tree against a schema.
    Assumes that we've already executed the optimizer's qualify_columns step.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|sqlglot.optimizer.Schema): Database schema.
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps a data type to the set of data types that it can be coerced into.
    Returns:
        sqlglot.Expression: expression annotated with types
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


def _annotate_with_type_lambda(data_type):
    """
    Build an annotator that unconditionally assigns `data_type` to an expression.

    A factory is used (instead of a lambda written inside a comprehension) so that
    `data_type` is bound per call and every generated annotator keeps its own type.
    """
    return lambda self, expr: self._annotate_with_type(expr, data_type)


def _build_coercion_table(*chains):
    """
    Build a coercion table from chains of data types ordered from narrowest to widest.

    Every type in a chain maps to the set of all types that come after it in that
    chain; the last (widest) type maps to an empty set.
    """
    table = {}
    for chain in chains:
        for index, data_type in enumerate(chain):
            table[data_type] = set(chain[index + 1 :])
    return table


# Expressions whose result type does not depend on their arguments, grouped by that type.
_FIXED_TYPE_EXPRESSIONS = {
    exp.DataType.Type.BIGINT: (
        exp.ApproxDistinct,
        exp.ArraySize,
        exp.Count,
        exp.Length,
    ),
    exp.DataType.Type.BOOLEAN: (
        exp.Between,
        exp.Boolean,
        exp.In,
        exp.RegexpLike,
    ),
    exp.DataType.Type.DATE: (
        exp.CurrentDate,
        exp.DateAdd,
        exp.DateStrToDate,
        exp.DateSub,
        exp.DiToDate,
        exp.StrToDate,
        exp.TimeStrToDate,
        exp.TsOrDsToDate,
    ),
    exp.DataType.Type.DATETIME: (
        exp.CurrentDatetime,
        exp.DatetimeAdd,
        exp.DatetimeSub,
    ),
    exp.DataType.Type.DOUBLE: (
        exp.ApproxQuantile,
        exp.Avg,
        exp.Exp,
        exp.Ln,
        exp.Log,
        exp.Log2,
        exp.Log10,
        exp.Pow,
        exp.Quantile,
        exp.Round,
        exp.SafeDivide,
        exp.Sqrt,
        exp.Stddev,
        exp.StddevPop,
        exp.StddevSamp,
        exp.Variance,
        exp.VariancePop,
    ),
    exp.DataType.Type.INT: (
        exp.Ceil,
        exp.DateDiff,
        exp.DateToDi,
        exp.DatetimeDiff,
        exp.Extract,
        exp.Floor,
        exp.Levenshtein,
        exp.StrPosition,
        exp.TimeDiff,
        exp.TimestampDiff,
        exp.TsOrDiToDi,
    ),
    exp.DataType.Type.INTERVAL: (exp.Interval,),
    exp.DataType.Type.MAP: (
        exp.Map,
        exp.VarMap,
    ),
    exp.DataType.Type.NULL: (exp.Null,),
    exp.DataType.Type.TIMESTAMP: (
        exp.CurrentTime,
        exp.CurrentTimestamp,
        exp.StrToTime,
        exp.TimeAdd,
        exp.TimeStrToTime,
        exp.TimeSub,
        exp.TimestampAdd,
        exp.TimestampSub,
        exp.UnixToTime,
    ),
    exp.DataType.Type.TINYINT: (
        exp.Day,
        exp.Month,
        exp.Week,
        exp.Year,
    ),
    exp.DataType.Type.UNKNOWN: (exp.Anonymous,),
    exp.DataType.Type.VARCHAR: (
        exp.ArrayConcat,
        exp.Concat,
        exp.ConcatWs,
        exp.DateToDateStr,
        exp.GroupConcat,
        exp.Initcap,
        exp.Lower,
        exp.SafeConcat,
        exp.Substring,
        exp.TimeToStr,
        exp.TimeToTimeStr,
        exp.Trim,
        exp.TsOrDsToDateStr,
        exp.UnixToStr,
        exp.UnixToTimeStr,
        exp.Upper,
    ),
}


class TypeAnnotator:
    """
    Walks an expression tree and assigns a type to every node.

    Attributes:
        ANNOTATORS: Default mapping from expression class to a callable
            `(annotator, expression) -> expression` that sets the expression's type.
        COERCES_TO: Default mapping from a data type to the set of data types it can
            be coerced into.
        TRAVERSABLES: Expression classes that are annotated scope by scope.
    """

    ANNOTATORS = {
        # Generic defaults first, so that the explicit entries below override them.
        **{
            expr_type: lambda self, expr: self._annotate_unary(expr)
            for expr_type in subclasses(exp.__name__, exp.Unary)
        },
        **{
            expr_type: lambda self, expr: self._annotate_binary(expr)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expr_types in _FIXED_TYPE_EXPRESSIONS.items()
            for expr_type in expr_types
        },
        # Expressions whose type is derived from the node itself or from its arguments.
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Sum: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions", promote=True
        ),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
    }

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = _build_coercion_table(
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        (
            exp.DataType.Type.CHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        ),
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        (
            exp.DataType.Type.TINYINT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        ),
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        (
            exp.DataType.Type.DATE,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        ),
    )

    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """
        Args:
            schema: Schema object used to look up the types of table columns.
            annotators (dict): Overrides ANNOTATORS when truthy.
            coerces_to (dict): Overrides COERCES_TO when truthy.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        """
        Annotate `expression` and everything beneath it, then return it.

        For traversable expressions, each scope is processed in the order yielded by
        `traverse_scope`: its qualified column references are typed first (from the
        schema or from the source scope's projections), then the scope's own
        expression is annotated.
        """
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                # Source name -> {output column name -> expression producing it}.
                selects = {}
                for name, source in scope.sources.items():
                    if not isinstance(source, Scope):
                        continue
                    if isinstance(source.expression, exp.UDTF):
                        values = []

                        if isinstance(source.expression, exp.Lateral):
                            # Only an exploded lateral contributes a value: its operand.
                            if isinstance(source.expression.this, exp.Explode):
                                values = [source.expression.this.this]
                        else:
                            # NOTE(review): presumably the first row of a VALUES
                            # clause -- confirm against the UDTF expression types.
                            values = source.expression.expressions[0].expressions

                        if not values:
                            continue

                        selects[name] = dict(zip(source.expression.alias_column_names, values))
                    else:
                        selects[name] = {
                            select.alias_or_name: select for select in source.expression.selects
                        }
                # First annotate the current scope's column references
                for col in scope.columns:
                    if not col.table:
                        continue

                    source = scope.sources.get(col.table)
                    if isinstance(source, exp.Table):
                        col.type = self.schema.get_column_type(source, col)
                    elif source and col.table in selects and col.name in selects[col.table]:
                        col.type = selects[col.table][col.name].type
                # Then (possibly) annotate the remaining expressions in the scope
                self._maybe_annotate(scope.expression)
        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression):
        """
        Annotate `expression` unless it already has a type.

        Expression classes without a registered annotator are typed as UNKNOWN
        (their arguments are still annotated).
        """
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression):
        """Annotate every child expression of `expression` and return `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(self, type1, type2):
        """
        Pick the resulting type of combining `type1` and `type2`.

        Both arguments may be `exp.DataType` instances or `exp.DataType.Type` members.
        Returns `type2` if `type1` is registered as coercible into it, else `type1`.
        """
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this
        if isinstance(type2, exp.DataType):
            type2 = type2.this

        if exp.DataType.Type.NULL in (type1, type2):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, {}) else type1

    def _annotate_binary(self, expression):
        """
        Annotate a binary expression from the types of its two operands.

        Connectors are BOOLEAN, NULL when both sides are NULL, or a nullable BOOLEAN
        when exactly one side is NULL. Other predicates are BOOLEAN. Everything else
        takes the coerced type of its operands.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, exp.Predicate):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    def _annotate_unary(self, expression):
        """
        Annotate a single-operand expression.

        Logical negation (NOT) is BOOLEAN; every other expression routed here takes
        the type of its operand.
        """
        self._annotate_args(expression)

        # Checking for exp.Not specifically (rather than for any Condition) keeps
        # arithmetic negation and bitwise NOT from being mistyped as BOOLEAN.
        if isinstance(expression, exp.Not):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    def _annotate_literal(self, expression):
        """Type a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    def _annotate_with_type(self, expression, target_type):
        """Assign `target_type` to `expression`, then annotate its arguments."""
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """
        Derive the type of `expression` from the types of the named arguments.

        Args:
            expression: Expression to annotate.
            *args: Names of the arguments (single expressions or lists) to consider.
            promote: If true, widen an integer result to BIGINT and a float result
                to DOUBLE.
        """
        self._annotate_args(expression)
        expressions = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        # Fold the argument types pairwise into a single resulting type.
        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
def
annotate_types(expression, schema=None, annotators=None, coerces_to=None):
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Infer the type of every node in an expression syntax tree and record it on the node.
    The optimizer's qualify_columns step is expected to have been run beforehand.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|sqlglot.optimizer.Schema): Database schema.
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps a data type to the set of data types that it can be coerced into.
    Returns:
        sqlglot.Expression: expression annotated with types
    """
    # Normalize the schema first, then hand everything to a one-off annotator.
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)
Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.
Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression (sqlglot.Expression): Expression to annotate.
- schema (dict|sqlglot.optimizer.Schema): Database schema.
- annotators (dict): Maps expression type to corresponding annotation function.
- coerces_to (dict): Maps a data type to the set of data types that it can be coerced into.
Returns:
sqlglot.Expression: expression annotated with types
class
TypeAnnotator:
35class TypeAnnotator: 36 ANNOTATORS = { 37 **{ 38 expr_type: lambda self, expr: self._annotate_unary(expr) 39 for expr_type in subclasses(exp.__name__, exp.Unary) 40 }, 41 **{ 42 expr_type: lambda self, expr: self._annotate_binary(expr) 43 for expr_type in subclasses(exp.__name__, exp.Binary) 44 }, 45 exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), 46 exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), 47 exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()), 48 exp.Alias: lambda self, expr: self._annotate_unary(expr), 49 exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), 50 exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), 51 exp.Literal: lambda self, expr: self._annotate_literal(expr), 52 exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), 53 exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), 54 exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), 55 exp.ApproxDistinct: lambda self, expr: self._annotate_with_type( 56 expr, exp.DataType.Type.BIGINT 57 ), 58 exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 59 exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), 60 exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), 61 exp.Sum: lambda self, expr: self._annotate_by_args( 62 expr, "this", "expressions", promote=True 63 ), 64 exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 65 exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), 66 exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), 67 exp.CurrentDatetime: lambda self, expr: self._annotate_with_type( 68 expr, exp.DataType.Type.DATETIME 69 ), 70 
exp.CurrentTime: lambda self, expr: self._annotate_with_type( 71 expr, exp.DataType.Type.TIMESTAMP 72 ), 73 exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type( 74 expr, exp.DataType.Type.TIMESTAMP 75 ), 76 exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), 77 exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), 78 exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 79 exp.DatetimeAdd: lambda self, expr: self._annotate_with_type( 80 expr, exp.DataType.Type.DATETIME 81 ), 82 exp.DatetimeSub: lambda self, expr: self._annotate_with_type( 83 expr, exp.DataType.Type.DATETIME 84 ), 85 exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 86 exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 87 exp.TimestampAdd: lambda self, expr: self._annotate_with_type( 88 expr, exp.DataType.Type.TIMESTAMP 89 ), 90 exp.TimestampSub: lambda self, expr: self._annotate_with_type( 91 expr, exp.DataType.Type.TIMESTAMP 92 ), 93 exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 94 exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), 95 exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), 96 exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 97 exp.DateStrToDate: lambda self, expr: self._annotate_with_type( 98 expr, exp.DataType.Type.DATE 99 ), 100 exp.DateToDateStr: lambda self, expr: self._annotate_with_type( 101 expr, exp.DataType.Type.VARCHAR 102 ), 103 exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 104 exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), 105 exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), 106 exp.Exp: lambda 
self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 107 exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 108 exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"), 109 exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), 110 exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), 111 exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), 112 exp.SafeConcat: lambda self, expr: self._annotate_with_type( 113 expr, exp.DataType.Type.VARCHAR 114 ), 115 exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), 116 exp.GroupConcat: lambda self, expr: self._annotate_with_type( 117 expr, exp.DataType.Type.VARCHAR 118 ), 119 exp.ArrayConcat: lambda self, expr: self._annotate_with_type( 120 expr, exp.DataType.Type.VARCHAR 121 ), 122 exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), 123 exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), 124 exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), 125 exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), 126 exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL), 127 exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"), 128 exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), 129 exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 130 exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 131 exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 132 exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 133 exp.Log10: lambda self, expr: self._annotate_with_type(expr, 
exp.DataType.Type.DOUBLE), 134 exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), 135 exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), 136 exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 137 exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 138 exp.ApproxQuantile: lambda self, expr: self._annotate_with_type( 139 expr, exp.DataType.Type.DOUBLE 140 ), 141 exp.RegexpLike: lambda self, expr: self._annotate_with_type( 142 expr, exp.DataType.Type.BOOLEAN 143 ), 144 exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 145 exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 146 exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), 147 exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 148 exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), 149 exp.StrToTime: lambda self, expr: self._annotate_with_type( 150 expr, exp.DataType.Type.TIMESTAMP 151 ), 152 exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 153 exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 154 exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 155 exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 156 exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), 157 exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type( 158 expr, exp.DataType.Type.VARCHAR 159 ), 160 exp.TimeStrToDate: lambda self, expr: self._annotate_with_type( 161 expr, exp.DataType.Type.DATE 162 ), 163 exp.TimeStrToTime: lambda self, expr: self._annotate_with_type( 164 expr, exp.DataType.Type.TIMESTAMP 
165 ), 166 exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), 167 exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type( 168 expr, exp.DataType.Type.VARCHAR 169 ), 170 exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), 171 exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), 172 exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), 173 exp.UnixToTime: lambda self, expr: self._annotate_with_type( 174 expr, exp.DataType.Type.TIMESTAMP 175 ), 176 exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type( 177 expr, exp.DataType.Type.VARCHAR 178 ), 179 exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), 180 exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), 181 exp.VariancePop: lambda self, expr: self._annotate_with_type( 182 expr, exp.DataType.Type.DOUBLE 183 ), 184 exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), 185 exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), 186 } 187 188 # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html 189 COERCES_TO = { 190 # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT 191 exp.DataType.Type.TEXT: set(), 192 exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, 193 exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, 194 exp.DataType.Type.NCHAR: { 195 exp.DataType.Type.VARCHAR, 196 exp.DataType.Type.NVARCHAR, 197 exp.DataType.Type.TEXT, 198 }, 199 exp.DataType.Type.CHAR: { 200 exp.DataType.Type.NCHAR, 201 exp.DataType.Type.VARCHAR, 202 exp.DataType.Type.NVARCHAR, 203 exp.DataType.Type.TEXT, 204 }, 205 # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE 206 exp.DataType.Type.DOUBLE: set(), 207 exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, 208 
# NOTE(review): this chunk opens inside a class-level dict literal whose head is above the
# visible lines (presumably COERCES_TO, which __init__ below falls back to -- confirm).
# If so, each key's set lists the types that _maybe_coerce will widen that key to.
        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
        exp.DataType.Type.BIGINT: {
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.INT: {
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.SMALLINT: {
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.TINYINT: {
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        exp.DataType.Type.TIMESTAMPLTZ: set(),
        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
        exp.DataType.Type.TIMESTAMP: {
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATETIME: {
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATE: {
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
    }

    # Expression classes for which annotate() walks every scope returned by
    # traverse_scope before annotating the expression itself.
    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """
        Store the schema and the lookup tables used during annotation.

        Args:
            schema: Object used to resolve column types; annotate() calls its
                get_column_type(table, column) method.
            annotators (dict): Maps expression class to annotation function. A falsy
                value (None or an empty dict) falls back to the class-level ANNOTATORS.
            coerces_to (dict): Maps a type to the set of types it may be coerced into.
                A falsy value falls back to the class-level COERCES_TO.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        """
        Annotate `expression` with types and return it.

        When the expression is an instance of one of TRAVERSABLES, every scope yielded
        by traverse_scope is processed first: columns that point at a table get their
        type from the schema, columns that point at a child scope get the type of the
        matching projection, and then the scope's own expression is annotated. The
        expression itself is finally passed through _maybe_annotate.

        Args:
            expression: Expression to annotate; types are assigned in place.

        Returns:
            The result of self._maybe_annotate(expression).
        """
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                # source name -> {output column name -> expression producing it}
                selects = {}
                for name, source in scope.sources.items():
                    # Only child scopes are indexed here; other sources (tables) are
                    # resolved against the schema in the column loop below.
                    if not isinstance(source, Scope):
                        continue
                    if isinstance(source.expression, exp.UDTF):
                        values = []

                        if isinstance(source.expression, exp.Lateral):
                            # Only an Explode contributes a value (its argument); any
                            # other Lateral leaves `values` empty and is skipped below.
                            if isinstance(source.expression.this, exp.Explode):
                                values = [source.expression.this.this]
                        else:
                            # Takes the expressions of the first entry only
                            # (presumably the first row of a VALUES-like construct --
                            # confirm).
                            values = source.expression.expressions[0].expressions

                        if not values:
                            continue

                        # Pair the declared column aliases with the value expressions
                        # by position; zip stops at the shorter of the two.
                        selects[name] = {
                            alias: column
                            for alias, column in zip(
                                source.expression.alias_column_names,
                                values,
                            )
                        }
                    else:
                        selects[name] = {
                            select.alias_or_name: select for select in source.expression.selects
                        }
                # First annotate the current scope's column references
                for col in scope.columns:
                    # Columns without a table qualifier cannot be looked up here.
                    if not col.table:
                        continue

                    source = scope.sources.get(col.table)
                    if isinstance(source, exp.Table):
                        col.type = self.schema.get_column_type(source, col)
                    elif source and col.table in selects and col.name in selects[col.table]:
                        col.type = selects[col.table][col.name].type
                # Then (possibly) annotate the remaining expressions in the scope
                self._maybe_annotate(scope.expression)
        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression):
        """
        Annotate `expression` unless it already carries a type.

        The annotator is looked up by the expression's exact class in self.annotators;
        when none is registered the expression is annotated as UNKNOWN.

        Returns:
            The expression itself when it already has a truthy `type`, otherwise
            whatever the selected annotator (or _annotate_with_type) returns.
        """
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression):
        """
        Run _maybe_annotate on every value yielded by expression.iter_expressions().

        Returns:
            The same `expression` that was passed in.
        """
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(self, type1, type2):
        """
        Pick the resulting type for a pair of operand types.

        Args:
            type1: An exp.DataType or a bare type value.
            type2: An exp.DataType or a bare type value.

        Returns:
            NULL if either input is NULL; otherwise UNKNOWN if either input is UNKNOWN;
            otherwise type2 when it is listed in self.coerces_to for type1, else type1.
        """
        # Unwrap DataType instances so the comparisons below work on the bare type values
        if isinstance(type1, exp.DataType):
            type1 = type1.this
        if isinstance(type2, exp.DataType):
            type2 = type2.this

        # We propagate the NULL / UNKNOWN types upwards if found (NULL is checked first)
        if exp.DataType.Type.NULL in (type1, type2):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, {}) else type1

    def _annotate_binary(self, expression):
        """
        Annotate a two-operand expression after annotating its arguments.

        exp.Connector instances become NULL when both sides are NULL, a NULLABLE
        wrapping BOOLEAN when exactly one side is NULL, and BOOLEAN otherwise. Other
        exp.Predicate instances become BOOLEAN. Anything else gets the result of
        _maybe_coerce(left type, right type).

        Returns:
            The annotated expression.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, exp.Predicate):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    def _annotate_unary(self, expression):
        """
        Annotate a single-operand expression after annotating its arguments.

        exp.Condition instances other than exp.Paren become BOOLEAN; everything else
        (including exp.Paren) takes the type of `expression.this`.

        Returns:
            The annotated expression.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    def _annotate_literal(self, expression):
        """
        Annotate a literal: VARCHAR when `is_string`, INT when `is_int`, else DOUBLE.

        Returns:
            The annotated expression.
        """
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    def _annotate_with_type(self, expression, target_type):
        """
        Assign `target_type` to `expression`, then annotate its arguments.

        Returns:
            The annotated expression (as returned by _annotate_args).
        """
        # The type is set before descending into the arguments.
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """
        Derive the type of `expression` from the types of selected arguments.

        Args:
            expression: Expression to annotate.
            *args: Keys into expression.args. Each value is passed through ensure_list
                and its truthy entries are collected.
            promote (bool): When True, a resulting type in exp.DataType.INTEGER_TYPES is
                replaced by BIGINT and one in exp.DataType.FLOAT_TYPES by DOUBLE.

        Returns:
            The annotated expression. Its type is UNKNOWN when no argument expressions
            were collected.
        """
        self._annotate_args(expression)
        expressions = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        # Fold the collected argument types pairwise through _maybe_coerce; the first
        # iteration compares the first argument's type with itself.
        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            # NOTE(review): `.this` is read right after a bare type value was assigned,
            # so the `type` attribute presumably wraps assigned values in an
            # exp.DataType -- confirm against the Expression.type setter.
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
def annotate(self, expression):
    """
    Annotate `expression` with types and return it.

    Expressions that are not instances of self.TRAVERSABLES go straight to
    _maybe_annotate. Otherwise each scope yielded by traverse_scope is handled first:
    table-qualified columns are typed from the schema (for tables) or from the matching
    projection of a child scope, after which the scope's own expression is annotated.

    Args:
        expression: Expression to annotate; types are assigned in place.

    Returns:
        The result of self._maybe_annotate(expression).
    """
    if not isinstance(expression, self.TRAVERSABLES):
        # Non-traversable expressions need no scope handling
        return self._maybe_annotate(expression)

    for scope in traverse_scope(expression):
        # source name -> {output column name -> expression producing it}
        projections = {}

        for source_name, child in scope.sources.items():
            if not isinstance(child, Scope):
                continue

            node = child.expression

            if not isinstance(node, exp.UDTF):
                projections[source_name] = {
                    projection.alias_or_name: projection for projection in node.selects
                }
                continue

            if isinstance(node, exp.Lateral):
                # Only an Explode contributes a value; other laterals are skipped
                produced = [node.this.this] if isinstance(node.this, exp.Explode) else []
            else:
                produced = node.expressions[0].expressions

            if produced:
                # Positional pairing of declared aliases with the value expressions
                projections[source_name] = dict(zip(node.alias_column_names, produced))

        # Column references of the current scope are typed first ...
        for column in scope.columns:
            table_name = column.table
            if not table_name:
                continue

            origin = scope.sources.get(table_name)
            if isinstance(origin, exp.Table):
                column.type = self.schema.get_column_type(origin, column)
            elif origin and table_name in projections and column.name in projections[table_name]:
                column.type = projections[table_name][column.name].type

        # ... and then whatever is still untyped in the scope's expression
        self._maybe_annotate(scope.expression)

    return self._maybe_annotate(expression)