sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """Determines the base index offset for arrays."""

    WEEK_OFFSET = 0
    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Determines whether or not the table alias comes after tablesample."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Determines whether or not an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Determines whether or not `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Determines whether or not user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Determines whether or not `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """Determines how function names are going to be normalized."""

    LOG_BASE_FIRST = True
    """Determines whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Indicates the default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` format."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for quotes, identifiers and the corresponding escape characters
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                raise ValueError(f"Unknown dialect '{dialect_name}'.")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower(),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
class Dialects(str, Enum):
Dialects supported by SQLGlot.
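Since the enum subclasses `str`, its members can be passed anywhere a dialect name is expected. A minimal sketch:

import sqlglot
from sqlglot.dialects.dialect import Dialects

# Enum members are strings, so they work as `read`/`write` arguments.
sql = sqlglot.transpile(
    "SELECT IF(x > 0, 1, 0)", read=Dialects.MYSQL, write=Dialects.DUCKDB
)[0]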
class NormalizationStrategy(str, AutoName):
Specifies the strategy according to which identifiers should be normalized.

- LOWERCASE: unquoted identifiers are lowercased.
- UPPERCASE: unquoted identifiers are uppercased.
- CASE_SENSITIVE: always case-sensitive, regardless of quotes.
- CASE_INSENSITIVE: always case-insensitive, regardless of quotes.
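The strategy can be chosen per Dialect instance. A small sketch: the settings string is parsed by `Dialect.get_or_raise`, and the value is upper-cased before being matched against this enum.

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_insensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE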
class Dialect(metaclass=_Dialect):
def __init__(self, **kwargs) -> None:
INDEX_OFFSET: Determines the base index offset for arrays.

WEEK_OFFSET: Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY: Determines whether or not `UNNEST` table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE: Determines whether or not the table alias comes after tablesample.

NORMALIZATION_STRATEGY: Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT: Determines whether or not an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT: Determines whether or not the DPIPE token (`||`) is a string concatenation operator.

STRICT_STRING_CONCAT: Determines whether or not `CONCAT`'s arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES: Determines whether or not user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN: Determines whether or not `SEMI` or `ANTI` joins are supported.

NORMALIZE_FUNCTIONS: Determines how function names are going to be normalized.

LOG_BASE_FIRST: Determines whether the base comes first in the `LOG` function.

NULL_ORDERING: Indicates the default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.

TYPED_DIVISION: Whether the behavior of `a / b` depends on the types of `a` and `b`. False means `a / b` is always float division; True means `a / b` is integer division if both `a` and `b` are integers.

SAFE_DIVISION: Determines whether division by zero throws an error (`False`) or returns NULL (`True`).

CONCAT_COALESCE: A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.

TIME_MAPPING: Associates this dialect's time formats with their equivalent Python `strftime` format.

FORMAT_MAPPING: Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.

ESCAPE_SEQUENCES: Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from `SELECT *` queries.

PREFER_CTE_ALIAS_COLUMN: Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery. For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
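These class-level flags are what concrete dialect implementations override. As a sketch (the `Custom` dialect below is hypothetical, not part of sqlglot), subclassing `Dialect` is enough to register a new entry in the global registry via the `_Dialect` metaclass:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

class Custom(Dialect):
    # Overridden flags apply to every instance of this dialect.
    NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
    NULL_ORDERING = "nulls_are_large"

# The metaclass registered the class under its lowercased name.
assert isinstance(Dialect.get_or_raise("custom"), Custom)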
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
@classmethod
def format_time(cls, expression: t.Optional[str | exp.Expression]) -> t.Optional[exp.Expression]:
Converts a time format in this dialect to its equivalent Python `strftime` format.
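A short sketch using Hive, assuming its `TIME_MAPPING` translates `yyyy`/`MM`/`dd` to `%Y`/`%m`/`%d` (true at the time of writing):

from sqlglot.dialects.hive import Hive

# String inputs are assumed to be quoted, hence the surrounding ticks.
fmt = Hive.format_time("'yyyy-MM-dd'")
assert fmt.name == "%Y-%m-%d"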
def normalize_identifier(self, expression: E) -> E:
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
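A minimal sketch of the behaviors described above:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted, so normalization applies

assert Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name == "foo"
assert Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name == "FOO"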
def case_sensitive(self, text: str) -> bool:
Checks if text contains any case-sensitive characters, based on the dialect's rules.
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: always returns `True`; `"safe"`: only returns `True` if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
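For example, under Postgres's lowercasing rules (a sketch, assuming `Dialect.get_or_raise` returns a usable dialect object):

from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")  # lowercases unquoted identifiers

postgres.can_identify("foo", "safe")    # True: all-lowercase text survives resolution as-is
postgres.can_identify("Foo", "safe")    # False: the uppercase "F" makes it case-sensitive
postgres.can_identify("Foo", "always")  # True: identify unconditionally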
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
Adds quotes to a given identifier.

Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
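A sketch of the `identify=False` behavior (again assuming `Dialect.get_or_raise` returns a usable dialect object):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

pg = Dialect.get_or_raise("postgres")

pg.quote_identifier(exp.to_identifier("foo"), identify=False).args.get("quoted")      # False: safe name
pg.quote_identifier(exp.to_identifier("Foo"), identify=False).args.get("quoted")      # True: case-sensitive
pg.quote_identifier(exp.to_identifier("foo bar"), identify=False).args.get("quoted")  # True: unsafe characters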
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
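Dialects typically wire factories like this into their generator's TRANSFORMS mapping. A hedged sketch with a hypothetical dialect (`MyDialect` is illustrative, not part of sqlglot):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, if_sql
from sqlglot.generator import Generator

class MyDialect(Dialect):  # hypothetical dialect, for illustration only
    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            # Render exp.If as IFF(cond, true, false), defaulting the false
            # branch to NULL when it's absent.
            exp.If: if_sql(name="IFF", false_value="NULL"),
        }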
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
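A sketch of the SQL this helper produces, driven through a bare `Generator` for illustration (real dialects invoke it via their TRANSFORMS):

from sqlglot import exp
from sqlglot.dialects.dialect import str_position_sql
from sqlglot.generator import Generator

gen = Generator()
needle = exp.StrPosition(this=exp.column("haystack"), substr=exp.Literal.string("x"))

str_position_sql(gen, needle)  # "STRPOS(haystack, 'x')"

# With a position, the search is offset via SUBSTR and re-based afterwards.
needle.set("position", exp.Literal.number(5))
str_position_sql(gen, needle)  # "STRPOS(SUBSTR(haystack, 5), 'x') + 5 - 1"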
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target SQL dialect.
        default: the default format; if `True`, the dialect's default time format is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
Helper used for time expressions.

Arguments:
- exp_class: the expression class to instantiate.
- dialect: target SQL dialect.
- default: the default format; if `True`, the dialect's default time format (`TIME_FORMAT`) is used.

Returns:
A callable that can be used to return the appropriately formatted time expression.
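A hedged sketch of how a dialect's parser might register it (the `TO_DATE` entry and `MyParser` are illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import format_time_lambda
from sqlglot.parser import Parser

class MyParser(Parser):  # hypothetical parser, for illustration only
    FUNCTIONS = {
        **Parser.FUNCTIONS,
        # TO_DATE(col[, fmt]) -> exp.TsOrDsToDate; when fmt is missing, fall
        # back to the dialect's default time format.
        "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive", default=True),
    }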
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
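A hedged sketch of the effect: Presto-style DDL declares partition columns via a table property whose value is an array of names, and writing it for Hive moves those columns out of the main schema and into PARTITIONED BY (exact output may vary by sqlglot version):

import sqlglot

sqlglot.transpile(
    "CREATE TABLE x (a INT, dt VARCHAR) WITH (partitioned_by=ARRAY['dt'])",
    read="presto",
    write="hive",
)
# expected shape: ["CREATE TABLE x (a INT) PARTITIONED BY (dt VARCHAR)"]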
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
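A hedged sketch of registering it in a parser, normalizing dialect-specific unit abbreviations onto canonical names (the mapping entries and `MyParser` are illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.parser import Parser

class MyParser(Parser):  # hypothetical parser, for illustration only
    FUNCTIONS = {
        **Parser.FUNCTIONS,
        # DATEADD(unit, increment, date) -> exp.DateAdd, translating "dd"/"mm"
        # style abbreviations through the unit mapping.
        "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping={"dd": "DAY", "mm": "MONTH"}),
    }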
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
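A final hedged sketch of wiring this factory into a generator, mirroring how dialects in this module register transforms (`MyGenerator` is illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import date_delta_sql
from sqlglot.generator import Generator

class MyGenerator(Generator):  # hypothetical generator, for illustration only
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        # Render exp.DateAdd as DATE_ADD(unit, increment, date), defaulting
        # the unit to DAY when it's absent.
        exp.DateAdd: date_delta_sql("DATE_ADD"),
    }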