sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )
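The _Dialect metaclass above is what makes dialects pluggable: subclassing Dialect registers the new class in the global registry (keyed by the lowercased class name when it isn't a Dialects enum member) and autofills its tokenizer, parser, and generator classes. A minimal sketch of a custom dialect, with the Custom name and its overrides chosen purely for illustration:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.tokens import Tokenizer, TokenType


class Custom(Dialect):
    class Tokenizer(Tokenizer):
        # Tokenize INT64 as an alias for BIGINT (illustrative override).
        KEYWORDS = {**Tokenizer.KEYWORDS, "INT64": TokenType.BIGINT}

    class Generator(Generator):
        # Render arrays as [...] instead of the default ARRAY(...) syntax.
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.Array: lambda self, e: f"[{self.expressions(e, flat=True)}]",
        }


# _Dialect.__new__ has already registered the class, so it resolves by name:
assert isinstance(Dialect.get_or_raise("custom"), Custom)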
Dialects

Dialects supported by SQLGlot.
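Since Dialects is a str enum, its members double as registry keys, so an enum member and its string value resolve to the same dialect:

from sqlglot.dialects.dialect import Dialect, Dialects

assert Dialects.DUCKDB.value == "duckdb"
assert Dialect.get_or_raise(Dialects.DUCKDB) == Dialect.get_or_raise("duckdb")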
NormalizationStrategy

Specifies the strategy according to which identifiers should be normalized:

LOWERCASE: unquoted identifiers are lowercased.
UPPERCASE: unquoted identifiers are uppercased.
CASE_SENSITIVE: always case-sensitive, regardless of quotes.
CASE_INSENSITIVE: always case-insensitive, regardless of quotes.
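The strategy can be supplied when constructing a Dialect; strings are uppercased and coerced to the enum in Dialect.__init__. A small sketch:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

d = Dialect(normalization_strategy="case_insensitive")
assert d.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE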
Dialect

Notable class-level settings (each dialect overrides these as needed):

WEEK_OFFSET: first day of the week in DATE_TRUNC(week). Defaults to 0 (Monday); -1 would be Sunday.
TABLESAMPLE_SIZE_IS_PERCENT: whether a size in the table sample clause represents percentage.
NORMALIZATION_STRATEGY: specifies the strategy according to which identifiers should be normalized.
NORMALIZE_FUNCTIONS: determines how function names are going to be normalized. Possible values: "upper" or True (convert names to uppercase), "lower" (convert names to lowercase), False (disable function name normalization).
LOG_BASE_FIRST: whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG).
NULL_ORDERING: default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".
TYPED_DIVISION: whether the behavior of a / b depends on the types of a and b. False means a / b is always float division; True means a / b is integer division if both a and b are integers.
CONCAT_COALESCE: a NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.
TIME_MAPPING: associates this dialect's time formats with their equivalent Python strftime formats.
FORMAT_MAPPING: helper used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie is constructed off of TIME_MAPPING.
ESCAPE_SEQUENCES: mapping of an unescaped escape sequence to the corresponding character.
PSEUDOCOLUMNS: columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.
PREFER_CTE_ALIAS_COLUMN: some dialects, such as Snowflake, allow referencing a CTE column alias in the HAVING clause of the CTE. This flag causes the CTE alias columns to override any projection aliases in the subquery. For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
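These settings are plain class attributes, so the easiest way to see how a given dialect is configured is to read them off the registered class. An illustrative check (the printed values depend on the installed sqlglot version):

from sqlglot.dialects.dialect import Dialect

hive = Dialect.get_or_raise("hive")
print(hive.NULL_ORDERING)             # e.g. "nulls_are_small"
print(hive.TIME_MAPPING.get("yyyy"))  # Hive's year token, mapped to strftime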
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:

    >>> dialect = dialect_class = get_or_raise("duckdb")
    >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
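A short usage sketch of the settings string form (the error message shown is indicative):

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Key-value pairs after the dialect name are forwarded to the constructor
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

# Unknown names raise, with a close-match suggestion when one exists
try:
    Dialect.get_or_raise("postgress")
except ValueError as e:
    print(e)  # e.g. Unknown dialect 'postgress'. Did you mean postgres?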
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Converts a time format in this dialect to its equivalent Python `strftime` format."""
    if isinstance(expression, str):
        return exp.Literal.string(
            # the time formats are quoted
            format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
        )

    if expression and expression.is_string:
        return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

    return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
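For example, with Hive's `TIME_MAPPING` (where `yyyy`, `MM` and `dd` map to `%Y`, `%m` and `%d`), a minimal sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

fmt = Dialect["hive"].format_time(exp.Literal.string("yyyy-MM-dd"))
print(fmt.sql())  # '%Y-%m-%d'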
def normalize_identifier(self, expression: E) -> E:
    """
    Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

    For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
    lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
    it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
    and so any normalization would be prohibited in order to avoid "breaking" the identifier.

    There are also dialects like Spark, which are case-insensitive even when quotes are
    present, and dialects like MySQL, whose resolution rules match those employed by the
    underlying operating system, for example they may always be case-sensitive in Linux.

    Finally, the normalization behavior of some engines can even be controlled through flags,
    like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

    SQLGlot aims to understand and handle all of these different behaviors gracefully, so
    that it can analyze queries in the optimizer and successfully capture their semantics.
    """
    if (
        isinstance(expression, exp.Identifier)
        and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        and (
            not expression.quoted
            or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
        )
    ):
        expression.set(
            "this",
            (
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower()
            ),
        )

    return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
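A minimal sketch of the behavior described above:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")
snowflake = Dialect.get_or_raise("snowflake")

# Unquoted identifiers are folded according to each dialect's strategy
assert postgres.normalize_identifier(exp.to_identifier("FoO")).name == "foo"
assert snowflake.normalize_identifier(exp.to_identifier("FoO")).name == "FOO"

# Quoted identifiers are treated as case-sensitive and left untouched
assert postgres.normalize_identifier(exp.to_identifier("FoO", quoted=True)).name == "FoO"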
def case_sensitive(self, text: str) -> bool:
    """Checks if text contains any case sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.

    Returns:
        Whether the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify == "safe":
        return not self.case_sensitive(text)

    return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: Always returns `True`. `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
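A sketch using Postgres, which lowercases unquoted identifiers:

from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")

assert not postgres.case_sensitive("foo")  # all-lowercase survives normalization here
assert postgres.case_sensitive("FoO")      # uppercase characters would be folded away

assert postgres.can_identify("FoO", "always")
assert not postgres.can_identify("FoO", "safe")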
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
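A sketch of the `identify=False` behavior:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

postgres = Dialect.get_or_raise("postgres")

# Only "unsafe" identifiers get quoted when identify=False
assert postgres.quote_identifier(exp.to_identifier("foo"), identify=False).quoted is False
assert postgres.quote_identifier(exp.to_identifier("FoO"), identify=False).quoted is True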
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    if isinstance(path, exp.Literal):
        path_text = path.name
        if path.is_number:
            path_text = f"[{path_text}]"

        try:
            return parse_json_path(path_text)
        except ParseError as e:
            logger.warning(f"Invalid JSON path syntax. {str(e)}")

    return path
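This hook mostly delegates to the JSONPath parser; a rough sketch with the base dialect:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

base = Dialect()
path = base.to_json_path(exp.Literal.string("$.a.b"))
print(repr(path))  # an exp.JSONPath tree: root, key 'a', key 'b'

# Non-literal paths are passed through unchanged
col = exp.column("c")
assert base.to_json_path(col) is col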
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
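The helpers below are written to be plugged into a dialect's `Generator.TRANSFORMS` map. Calling one directly against a bare generator shows the shape of the output (a sketch; `IIF` is just an example target name):

from sqlglot import parse_one
from sqlglot.dialects.dialect import Dialect, if_sql

gen = Dialect().generator()
node = parse_one("IF(x > 1, 'a', 'b')")  # parses to exp.If
print(if_sql(name="IIF")(gen, node))  # IIF(x > 1, 'a', 'b')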
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
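A sketch of the two output shapes, using hypothetical column/literal operands:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, str_position_sql

gen = Dialect().generator()
node = exp.StrPosition(this=exp.column("haystack"), substr=exp.Literal.string("needle"))
print(str_position_sql(gen, node))  # STRPOS(haystack, 'needle')

node.set("position", exp.Literal.number(5))
print(str_position_sql(gen, node))  # STRPOS(SUBSTR(haystack, 5), 'needle') + 5 - 1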
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
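A sketch that builds an `exp.TimeToStr` from Hive-style arguments (the argument list mimics what the parser would pass):

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.TimeToStr, "hive")
node = builder([exp.column("dt"), exp.Literal.string("yyyy-MM-dd")])
print(node.args["format"].sql())  # '%Y-%m-%d' (converted via Hive's TIME_MAPPING)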
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder
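MySQL, for example, parses `DATE_ADD` through a builder like this, so the `INTERVAL` argument gets unwrapped into the resulting node (a sketch):

from sqlglot import parse_one

node = parse_one("DATE_ADD(d, INTERVAL 1 WEEK)", read="mysql")
print(type(node).__name__, node.text("unit"))  # DateAdd WEEK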
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
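A sketch pairing the MySQL-style parse shown earlier with a `DATEADD`-style rendering:

from sqlglot import parse_one
from sqlglot.dialects.dialect import Dialect, date_delta_sql

gen = Dialect().generator()
node = parse_one("DATE_ADD(d, INTERVAL 1 WEEK)", read="mysql")
print(date_delta_sql("DATEADD")(gen, node))  # DATEADD(WEEK, 1, d)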
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
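A sketch of the builder in isolation (outputs are indicative):

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("j"), exp.Literal.string("a"), exp.Literal.number(0)])
print(node.expression.sql())  # '$.a[0]' -- the varargs become one structured JSONPath

# One-based dialects shift numeric subscripts down by one
one_based = build_json_extract_path(exp.JSONExtract, zero_based_indexing=False)
print(one_based([exp.column("j"), exp.Literal.number(1)]).expression.sql())  # '$[0]'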
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
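A rough sketch with the base generator (the exact alias rendering varies by dialect):

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import Dialect, filter_array_using_unnest

gen = Dialect().generator()
node = exp.ArrayFilter(
    this=exp.column("a"),
    expression=exp.Lambda(this=parse_one("x > 0"), expressions=[exp.to_identifier("x")]),
)
print(filter_array_using_unnest(gen, node))
# roughly: ARRAY(SELECT x FROM UNNEST(a) AS (x) WHERE x > 0)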