sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)


class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


# Metaclass that registers each Dialect subclass by name and autofills derived attributes
class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
        }

        # Dialects is a str enum, so a member compares equal to its string value here
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond.copy(), 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    expression = expression.copy()
    this, *rest_args = expression.expressions
    for arg in rest_args:
        this = exp.DPipe(this=this, expression=arg)

    return self.sql(this)


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def simplify_literal(expression: E) -> E:
    if not isinstance(expression.expression, exp.Literal):
        from sqlglot.optimizer.simplify import simplify

        simplify(expression.expression)

    return expression


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
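Taken together, the module provides a registry (built by _Dialect), a settings object (Dialect), and a set of reusable parser and generator callbacks. For orientation, a minimal usage sketch; it assumes the bundled dialects have been registered, which a plain import sqlglot takes care of:

import sqlglot
from sqlglot.dialects.dialect import Dialect

# The metaclass registered every subclass under its Dialects name, so string lookups work:
duckdb = Dialect.get_or_raise("duckdb")

# transpile() exercises the whole pipeline: tokenize and parse with one dialect,
# generate with another.
print(sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive")[0])
# roughly: SELECT FROM_UNIXTIME(1618088028295 / 1000)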
Dialects enumerates the dialect names that sqlglot ships. Because it subclasses str, each member compares equal to its lowercase string value, so any API that accepts a dialect name takes either form; see the sketch below.
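A small sketch of that interchangeability, assuming the bundled dialects have been imported (importing anything from sqlglot does this):

from sqlglot.dialects.dialect import Dialect, Dialects

assert Dialects.DUCKDB == "duckdb"  # str-enum members compare equal to their values
assert Dialect["duckdb"] is Dialect.get_or_raise(Dialects.DUCKDB)  # either form keys the registry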
Dialect is the base class that the _Dialect metaclass assembles. Concrete dialects subclass it, override the class-level settings shown in the source above (INDEX_OFFSET, TIME_MAPPING, NULL_ORDERING, and so on), and optionally define nested Tokenizer, Parser and Generator classes; the metaclass registers each subclass by name and propagates the settings into those three classes. Instances expose the full pipeline through parse, parse_into, generate, transpile and tokenize, with the tokenizer property created lazily and cached.
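A round trip through one dialect's pipeline (a minimal sketch; any registered dialect name works):

from sqlglot.dialects.dialect import Dialect

dialect = Dialect.get_or_raise("duckdb")()   # resolve to a class, then instantiate
ast = dialect.parse("SELECT 1 AS x")[0]      # tokenize + parse -> expression tree
print(dialect.generate(ast))                 # render the tree back to SQL: SELECT 1 AS x
print(dialect.transpile("SELECT 1 AS x"))    # parse + generate in one call: ['SELECT 1 AS x']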
Dialect.get_or_raise resolves a DialectType value (a name, a Dialect class, a Dialect instance, or None) to a Dialect class, raising ValueError for an unknown name; falsy input returns the class it was called on.
Dialect.format_time translates a time-format string or string literal from the dialect's own notation into a python strftime-style literal, using TIME_MAPPING and TIME_TRIE. A plain string argument is assumed to carry its surrounding quotes, hence the [1:-1] slice.
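For example, with Hive's stock TIME_MAPPING ('yyyy' maps to '%Y', and so on), a sketch:

from sqlglot.dialects.dialect import Dialect

hive = Dialect.get_or_raise("hive")
print(hive.format_time("'yyyy-MM-dd'").this)  # %Y-%m-%d -- the [1:-1] slice dropped the quotes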
Dialect.normalize_identifier normalizes an unquoted identifier to either lower or upper case, essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive (RESOLVES_IDENTIFIERS_AS_UPPERCASE is None), identifiers are normalized whether quoted or not.
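A sketch contrasting the default (lowercasing) behavior with a dialect that resolves identifiers to uppercase, assuming Snowflake's RESOLVES_IDENTIFIERS_AS_UPPERCASE = True:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.dialects.snowflake import Snowflake

print(Dialect.normalize_identifier(exp.to_identifier("FOO")).this)    # foo
print(Snowflake.normalize_identifier(exp.to_identifier("foo")).this)  # FOO
print(Snowflake.normalize_identifier(exp.to_identifier("foo", quoted=True)).this)  # foo -- quoted, left alone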
Dialect.case_sensitive checks whether text contains any case-sensitive characters under the dialect's resolution rules: when identifiers resolve to uppercase, lowercase characters are the unsafe ones, and vice versa. A combined sketch follows can_identify below.
Dialect.can_identify checks whether text can be identified, given an identify option:
- "always" or True: always returns True.
- "safe": True only if the identifier is case-insensitive.
It returns whether or not the given text can be identified.
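A combined sketch of the two checks under the default rules (unquoted identifiers resolve to lowercase, so uppercase characters are the case-sensitive ones):

from sqlglot.dialects.dialect import Dialect

print(Dialect.case_sensitive("foo"))          # False
print(Dialect.case_sensitive("Foo"))          # True
print(Dialect.can_identify("Foo", "safe"))    # False -- quoting would pin the case
print(Dialect.can_identify("Foo", "always"))  # True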
Dialect.quote_identifier sets an Identifier's quoted flag when identify is requested, when its name is case-sensitive, or when the name fails SAFE_IDENTIFIER_RE.
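A sketch; with identify=False the quoted flag is set only when needed:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

print(Dialect.quote_identifier(exp.to_identifier("order by"), identify=False).sql())  # "order by"
print(Dialect.quote_identifier(exp.to_identifier("foo"), identify=False).sql())       # foo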
str_position_sql renders StrPosition as STRPOS; when a position argument is present, it is emulated with SUBSTR plus offset arithmetic.
var_map_sql flattens a Map or VarMap whose keys and values are array literals into a single MAP(k1, v1, k2, v2, ...) call; array columns cannot be converted, so they are passed through with an unsupported warning.
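These *_sql helpers are not called directly; a dialect installs them in its Generator's TRANSFORMS table. A sketch of a toy dialect (the name MyDialect is hypothetical; the wiring mirrors how the bundled dialects use these helpers):

from sqlglot import exp, generator
from sqlglot.dialects.dialect import Dialect, str_position_sql, var_map_sql

class MyDialect(Dialect):  # hypothetical; the metaclass registers it as "mydialect"
    class Generator(generator.Generator):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.StrPosition: str_position_sql,  # STRPOS(...)
            exp.VarMap: var_map_sql,            # MAP(k1, v1, k2, v2, ...)
        }

node = exp.StrPosition(this=exp.column("x"), substr=exp.Literal.string("a"))
print(MyDialect().generate(node))  # STRPOS(x, 'a')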
format_time_lambda is a helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: the target SQL dialect.
- default: the default format, with True meaning the dialect's TIME_FORMAT.
It returns a callable that builds the appropriately formatted time expression from a parsed argument list.
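A sketch in the spirit of how hive.py wires TO_DATE, with default=True so a missing format falls back to the dialect's TIME_FORMAT (assuming Hive's stock mapping):

from sqlglot import exp
from sqlglot.dialects.dialect import format_time_lambda

to_date = format_time_lambda(exp.TsOrDsToDate, "hive", default=True)
node = to_date([exp.column("dt")])   # no explicit format -> TIME_FORMAT is used
print(node.args["format"].this)      # %Y-%m-%d %H:%M:%S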
create_with_partitions_sql: in Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, those names are transformed into a schema of typed columns, and the corresponding columns are removed from the CREATE statement's column list.
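This shows up when transpiling Spark-style DDL, where partition columns are listed by name, into Hive-style DDL, where they carry their types outside the column list. A sketch that prints rather than asserts, since the exact output text can vary by version:

import sqlglot

sql = "CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)"
print(sqlglot.transpile(sql, read="spark", write="hive")[0])
# expected shape: CREATE TABLE t (a INT) PARTITIONED BY (b INT)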
parse_date_delta builds a parser callback for DATEADD/DATEDIFF-style functions: a three-argument call is treated as unit-based (unit first, target last); otherwise the unit defaults to DAY. An optional unit_mapping renames units.
parse_date_delta_with_interval builds a parser callback for the variant that takes an INTERVAL literal as its second argument, unpacking the interval's value and unit into the expression and raising ParseError when the argument is not an interval.
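A sketch of the unit-based shape, as in T-SQL's DATEADD(dd, 1, dt); the unit_mapping contents here are illustrative:

from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta

date_add = parse_date_delta(exp.DateAdd, unit_mapping={"dd": "DAY"})
node = date_add([exp.var("dd"), exp.Literal.number(1), exp.column("dt")])
print(type(node).__name__, node.args["unit"].sql())  # DateAdd DAY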
encode_decode_sql is used for Presto and DuckDB, whose encode/decode-style functions do not support a charset and assume utf-8; any other charset triggers an unsupported warning.
count_if_to_sum rewrites COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT arguments are not supported and are flagged.
trim_sql falls back to the Generator's plain TRIM/LTRIM/RTRIM rendering when the expression carries no trim characters or collation, and otherwise emits the verbose TRIM([position] [chars] FROM target [COLLATE ...]) syntax.
ts_or_ds_to_date_sql returns a generator callback that casts a timestamp-or-date-string value to DATE, first routing it through STRPTIME (via str_to_time_sql) when a format other than the dialect's default time or date format is attached.
regexp_extract_sql renders RegexpExtract as REGEXP_EXTRACT, warning that the position, occurrence and parameters arguments are unsupported in this form.
regexp_replace_sql does the same for RegexpReplace and REGEXP_REPLACE, with the replacement argument required.
pivot_column_names derives output column names for PIVOT aggregations: aliased aggregations contribute their alias, while unaliased ones are rendered with their identifiers unquoted (so the base parser's _parse_pivot does not double-quote them via to_identifier) and used as name suffixes.
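Finally, binary_from_function is the generic glue for parsing functions that are really binary operators; a sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import binary_from_function

parse_dpipe = binary_from_function(exp.DPipe)
print(parse_dpipe([exp.column("a"), exp.column("b")]).sql())  # a || b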