sqlglot.dialects.dialect
1from __future__ import annotations 2 3import logging 4import typing as t 5from enum import Enum, auto 6from functools import reduce 7 8from sqlglot import exp 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator 11from sqlglot.helper import AutoName, flatten, is_int, seq_get 12from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path 13from sqlglot.parser import Parser 14from sqlglot.time import TIMEZONES, format_time 15from sqlglot.tokens import Token, Tokenizer, TokenType 16from sqlglot.trie import new_trie 17 18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] 21 22 23if t.TYPE_CHECKING: 24 from sqlglot._typing import B, E, F 25 26logger = logging.getLogger("sqlglot") 27 28UNESCAPED_SEQUENCES = { 29 "\\a": "\a", 30 "\\b": "\b", 31 "\\f": "\f", 32 "\\n": "\n", 33 "\\r": "\r", 34 "\\t": "\t", 35 "\\v": "\v", 36 "\\\\": "\\", 37} 38 39 40class Dialects(str, Enum): 41 """Dialects supported by SQLGLot.""" 42 43 DIALECT = "" 44 45 ATHENA = "athena" 46 BIGQUERY = "bigquery" 47 CLICKHOUSE = "clickhouse" 48 DATABRICKS = "databricks" 49 DORIS = "doris" 50 DRILL = "drill" 51 DUCKDB = "duckdb" 52 HIVE = "hive" 53 MATERIALIZE = "materialize" 54 MYSQL = "mysql" 55 ORACLE = "oracle" 56 POSTGRES = "postgres" 57 PRESTO = "presto" 58 PRQL = "prql" 59 REDSHIFT = "redshift" 60 RISINGWAVE = "risingwave" 61 SNOWFLAKE = "snowflake" 62 SPARK = "spark" 63 SPARK2 = "spark2" 64 SQLITE = "sqlite" 65 STARROCKS = "starrocks" 66 TABLEAU = "tableau" 67 TERADATA = "teradata" 68 TRINO = "trino" 69 TSQL = "tsql" 70 71 72class NormalizationStrategy(str, AutoName): 73 """Specifies the strategy according to which identifiers should be normalized.""" 74 75 LOWERCASE = auto() 76 """Unquoted identifiers are lowercased.""" 77 78 UPPERCASE = auto() 79 """Unquoted identifiers are uppercased.""" 80 81 CASE_SENSITIVE = auto() 82 """Always case-sensitive, regardless of quotes.""" 83 84 CASE_INSENSITIVE = auto() 85 """Always case-insensitive, regardless of quotes.""" 86 87 88class _Dialect(type): 89 classes: t.Dict[str, t.Type[Dialect]] = {} 90 91 def __eq__(cls, other: t.Any) -> bool: 92 if cls is other: 93 return True 94 if isinstance(other, str): 95 return cls is cls.get(other) 96 if isinstance(other, Dialect): 97 return cls is type(other) 98 99 return False 100 101 def __hash__(cls) -> int: 102 return hash(cls.__name__.lower()) 103 104 @classmethod 105 def __getitem__(cls, key: str) -> t.Type[Dialect]: 106 return cls.classes[key] 107 108 @classmethod 109 def get( 110 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 111 ) -> t.Optional[t.Type[Dialect]]: 112 return cls.classes.get(key, default) 113 114 def __new__(cls, clsname, bases, attrs): 115 klass = super().__new__(cls, clsname, bases, attrs) 116 enum = Dialects.__members__.get(clsname.upper()) 117 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 118 119 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 120 klass.FORMAT_TRIE = ( 121 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 122 ) 123 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 124 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 125 126 base = seq_get(bases, 0) 127 base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),) 128 base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),) 129 base_parser = (getattr(base, "parser_class", Parser),) 130 base_generator = (getattr(base, "generator_class", Generator),) 131 132 klass.tokenizer_class = klass.__dict__.get( 133 "Tokenizer", type("Tokenizer", base_tokenizer, {}) 134 ) 135 klass.jsonpath_tokenizer_class = klass.__dict__.get( 136 "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {}) 137 ) 138 klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {})) 139 klass.generator_class = klass.__dict__.get( 140 "Generator", type("Generator", base_generator, {}) 141 ) 142 143 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 144 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 145 klass.tokenizer_class._IDENTIFIERS.items() 146 )[0] 147 148 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 149 return next( 150 ( 151 (s, e) 152 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 153 if t == token_type 154 ), 155 (None, None), 156 ) 157 158 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 159 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 160 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 161 klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 162 163 if "\\" in klass.tokenizer_class.STRING_ESCAPES: 164 klass.UNESCAPED_SEQUENCES = { 165 **UNESCAPED_SEQUENCES, 166 **klass.UNESCAPED_SEQUENCES, 167 } 168 169 klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()} 170 171 klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS 172 173 if enum not in ("", "bigquery"): 174 klass.generator_class.SELECT_KINDS = () 175 176 if enum not in ("", "athena", "presto", "trino"): 177 klass.generator_class.TRY_SUPPORTED = False 178 klass.generator_class.SUPPORTS_UESCAPE = False 179 180 if enum not in ("", "databricks", "hive", "spark", "spark2"): 181 modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() 182 for modifier in ("cluster", "distribute", "sort"): 183 modifier_transforms.pop(modifier, None) 184 185 klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms 186 187 if enum not in ("", "doris", "mysql"): 188 klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { 189 TokenType.STRAIGHT_JOIN, 190 } 191 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 192 TokenType.STRAIGHT_JOIN, 193 } 194 195 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 196 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 197 TokenType.ANTI, 198 TokenType.SEMI, 199 } 200 201 return klass 202 203 204class Dialect(metaclass=_Dialect): 205 INDEX_OFFSET = 0 206 """The base index offset for arrays.""" 207 208 WEEK_OFFSET = 0 209 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 210 211 UNNEST_COLUMN_ONLY = False 212 """Whether `UNNEST` table aliases are treated as column aliases.""" 213 214 ALIAS_POST_TABLESAMPLE = False 215 """Whether the table alias comes after tablesample.""" 216 217 TABLESAMPLE_SIZE_IS_PERCENT = False 218 """Whether a size in the table sample clause represents percentage.""" 219 220 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 221 """Specifies the strategy according to which identifiers should be normalized.""" 222 223 IDENTIFIERS_CAN_START_WITH_DIGIT = False 224 """Whether an unquoted identifier can start with a digit.""" 225 226 DPIPE_IS_STRING_CONCAT = True 227 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 228 229 STRICT_STRING_CONCAT = False 230 """Whether `CONCAT`'s arguments must be strings.""" 231 232 SUPPORTS_USER_DEFINED_TYPES = True 233 """Whether user-defined data types are supported.""" 234 235 SUPPORTS_SEMI_ANTI_JOIN = True 236 """Whether `SEMI` or `ANTI` joins are supported.""" 237 238 SUPPORTS_COLUMN_JOIN_MARKS = False 239 """Whether the old-style outer join (+) syntax is supported.""" 240 241 COPY_PARAMS_ARE_CSV = True 242 """Separator of COPY statement parameters.""" 243 244 NORMALIZE_FUNCTIONS: bool | str = "upper" 245 """ 246 Determines how function names are going to be normalized. 247 Possible values: 248 "upper" or True: Convert names to uppercase. 249 "lower": Convert names to lowercase. 250 False: Disables function name normalization. 251 """ 252 253 LOG_BASE_FIRST: t.Optional[bool] = True 254 """ 255 Whether the base comes first in the `LOG` function. 256 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 257 """ 258 259 NULL_ORDERING = "nulls_are_small" 260 """ 261 Default `NULL` ordering method to use if not explicitly set. 262 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 263 """ 264 265 TYPED_DIVISION = False 266 """ 267 Whether the behavior of `a / b` depends on the types of `a` and `b`. 268 False means `a / b` is always float division. 269 True means `a / b` is integer division if both `a` and `b` are integers. 270 """ 271 272 SAFE_DIVISION = False 273 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 274 275 CONCAT_COALESCE = False 276 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 277 278 HEX_LOWERCASE = False 279 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 280 281 DATE_FORMAT = "'%Y-%m-%d'" 282 DATEINT_FORMAT = "'%Y%m%d'" 283 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 284 285 TIME_MAPPING: t.Dict[str, str] = {} 286 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 287 288 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 289 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 290 FORMAT_MAPPING: t.Dict[str, str] = {} 291 """ 292 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 293 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 294 """ 295 296 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 297 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 298 299 PSEUDOCOLUMNS: t.Set[str] = set() 300 """ 301 Columns that are auto-generated by the engine corresponding to this dialect. 302 For example, such columns may be excluded from `SELECT *` queries. 303 """ 304 305 PREFER_CTE_ALIAS_COLUMN = False 306 """ 307 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 308 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 309 any projection aliases in the subquery. 310 311 For example, 312 WITH y(c) AS ( 313 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 314 ) SELECT c FROM y; 315 316 will be rewritten as 317 318 WITH y(c) AS ( 319 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 320 ) SELECT c FROM y; 321 """ 322 323 COPY_PARAMS_ARE_CSV = True 324 """ 325 Whether COPY statement parameters are separated by comma or whitespace 326 """ 327 328 # --- Autofilled --- 329 330 tokenizer_class = Tokenizer 331 jsonpath_tokenizer_class = JSONPathTokenizer 332 parser_class = Parser 333 generator_class = Generator 334 335 # A trie of the time_mapping keys 336 TIME_TRIE: t.Dict = {} 337 FORMAT_TRIE: t.Dict = {} 338 339 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 340 INVERSE_TIME_TRIE: t.Dict = {} 341 342 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 343 344 # Delimiters for string literals and identifiers 345 QUOTE_START = "'" 346 QUOTE_END = "'" 347 IDENTIFIER_START = '"' 348 IDENTIFIER_END = '"' 349 350 # Delimiters for bit, hex, byte and unicode literals 351 BIT_START: t.Optional[str] = None 352 BIT_END: t.Optional[str] = None 353 HEX_START: t.Optional[str] = None 354 HEX_END: t.Optional[str] = None 355 BYTE_START: t.Optional[str] = None 356 BYTE_END: t.Optional[str] = None 357 UNICODE_START: t.Optional[str] = None 358 UNICODE_END: t.Optional[str] = None 359 360 DATE_PART_MAPPING = { 361 "Y": "YEAR", 362 "YY": "YEAR", 363 "YYY": "YEAR", 364 "YYYY": "YEAR", 365 "YR": "YEAR", 366 "YEARS": "YEAR", 367 "YRS": "YEAR", 368 "MM": "MONTH", 369 "MON": "MONTH", 370 "MONS": "MONTH", 371 "MONTHS": "MONTH", 372 "D": "DAY", 373 "DD": "DAY", 374 "DAYS": "DAY", 375 "DAYOFMONTH": "DAY", 376 "DAY OF WEEK": "DAYOFWEEK", 377 "WEEKDAY": "DAYOFWEEK", 378 "DOW": "DAYOFWEEK", 379 "DW": "DAYOFWEEK", 380 "WEEKDAY_ISO": "DAYOFWEEKISO", 381 "DOW_ISO": "DAYOFWEEKISO", 382 "DW_ISO": "DAYOFWEEKISO", 383 "DAY OF YEAR": "DAYOFYEAR", 384 "DOY": "DAYOFYEAR", 385 "DY": "DAYOFYEAR", 386 "W": "WEEK", 387 "WK": "WEEK", 388 "WEEKOFYEAR": "WEEK", 389 "WOY": "WEEK", 390 "WY": "WEEK", 391 "WEEK_ISO": "WEEKISO", 392 "WEEKOFYEARISO": "WEEKISO", 393 "WEEKOFYEAR_ISO": "WEEKISO", 394 "Q": "QUARTER", 395 "QTR": "QUARTER", 396 "QTRS": "QUARTER", 397 "QUARTERS": "QUARTER", 398 "H": "HOUR", 399 "HH": "HOUR", 400 "HR": "HOUR", 401 "HOURS": "HOUR", 402 "HRS": "HOUR", 403 "M": "MINUTE", 404 "MI": "MINUTE", 405 "MIN": "MINUTE", 406 "MINUTES": "MINUTE", 407 "MINS": "MINUTE", 408 "S": "SECOND", 409 "SEC": "SECOND", 410 "SECONDS": "SECOND", 411 "SECS": "SECOND", 412 "MS": "MILLISECOND", 413 "MSEC": "MILLISECOND", 414 "MSECS": "MILLISECOND", 415 "MSECOND": "MILLISECOND", 416 "MSECONDS": "MILLISECOND", 417 "MILLISEC": "MILLISECOND", 418 "MILLISECS": "MILLISECOND", 419 "MILLISECON": "MILLISECOND", 420 "MILLISECONDS": "MILLISECOND", 421 "US": "MICROSECOND", 422 "USEC": "MICROSECOND", 423 "USECS": "MICROSECOND", 424 "MICROSEC": "MICROSECOND", 425 "MICROSECS": "MICROSECOND", 426 "USECOND": "MICROSECOND", 427 "USECONDS": "MICROSECOND", 428 "MICROSECONDS": "MICROSECOND", 429 "NS": "NANOSECOND", 430 "NSEC": "NANOSECOND", 431 "NANOSEC": "NANOSECOND", 432 "NSECOND": "NANOSECOND", 433 "NSECONDS": "NANOSECOND", 434 "NANOSECS": "NANOSECOND", 435 "EPOCH_SECOND": "EPOCH", 436 "EPOCH_SECONDS": "EPOCH", 437 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 438 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 439 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 440 "TZH": "TIMEZONE_HOUR", 441 "TZM": "TIMEZONE_MINUTE", 442 "DEC": "DECADE", 443 "DECS": "DECADE", 444 "DECADES": "DECADE", 445 "MIL": "MILLENIUM", 446 "MILS": "MILLENIUM", 447 "MILLENIA": "MILLENIUM", 448 "C": "CENTURY", 449 "CENT": "CENTURY", 450 "CENTS": "CENTURY", 451 "CENTURIES": "CENTURY", 452 } 453 454 @classmethod 455 def get_or_raise(cls, dialect: DialectType) -> Dialect: 456 """ 457 Look up a dialect in the global dialect registry and return it if it exists. 458 459 Args: 460 dialect: The target dialect. If this is a string, it can be optionally followed by 461 additional key-value pairs that are separated by commas and are used to specify 462 dialect settings, such as whether the dialect's identifiers are case-sensitive. 463 464 Example: 465 >>> dialect = dialect_class = get_or_raise("duckdb") 466 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 467 468 Returns: 469 The corresponding Dialect instance. 470 """ 471 472 if not dialect: 473 return cls() 474 if isinstance(dialect, _Dialect): 475 return dialect() 476 if isinstance(dialect, Dialect): 477 return dialect 478 if isinstance(dialect, str): 479 try: 480 dialect_name, *kv_pairs = dialect.split(",") 481 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 482 except ValueError: 483 raise ValueError( 484 f"Invalid dialect format: '{dialect}'. " 485 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 486 ) 487 488 result = cls.get(dialect_name.strip()) 489 if not result: 490 from difflib import get_close_matches 491 492 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 493 if similar: 494 similar = f" Did you mean {similar}?" 495 496 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 497 498 return result(**kwargs) 499 500 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 501 502 @classmethod 503 def format_time( 504 cls, expression: t.Optional[str | exp.Expression] 505 ) -> t.Optional[exp.Expression]: 506 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 507 if isinstance(expression, str): 508 return exp.Literal.string( 509 # the time formats are quoted 510 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 511 ) 512 513 if expression and expression.is_string: 514 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 515 516 return expression 517 518 def __init__(self, **kwargs) -> None: 519 normalization_strategy = kwargs.pop("normalization_strategy", None) 520 521 if normalization_strategy is None: 522 self.normalization_strategy = self.NORMALIZATION_STRATEGY 523 else: 524 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 525 526 self.settings = kwargs 527 528 def __eq__(self, other: t.Any) -> bool: 529 # Does not currently take dialect state into account 530 return type(self) == other 531 532 def __hash__(self) -> int: 533 # Does not currently take dialect state into account 534 return hash(type(self)) 535 536 def normalize_identifier(self, expression: E) -> E: 537 """ 538 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 539 540 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 541 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 542 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 543 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 544 545 There are also dialects like Spark, which are case-insensitive even when quotes are 546 present, and dialects like MySQL, whose resolution rules match those employed by the 547 underlying operating system, for example they may always be case-sensitive in Linux. 548 549 Finally, the normalization behavior of some engines can even be controlled through flags, 550 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 551 552 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 553 that it can analyze queries in the optimizer and successfully capture their semantics. 554 """ 555 if ( 556 isinstance(expression, exp.Identifier) 557 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 558 and ( 559 not expression.quoted 560 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 561 ) 562 ): 563 expression.set( 564 "this", 565 ( 566 expression.this.upper() 567 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 568 else expression.this.lower() 569 ), 570 ) 571 572 return expression 573 574 def case_sensitive(self, text: str) -> bool: 575 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 576 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 577 return False 578 579 unsafe = ( 580 str.islower 581 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 582 else str.isupper 583 ) 584 return any(unsafe(char) for char in text) 585 586 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 587 """Checks if text can be identified given an identify option. 588 589 Args: 590 text: The text to check. 591 identify: 592 `"always"` or `True`: Always returns `True`. 593 `"safe"`: Only returns `True` if the identifier is case-insensitive. 594 595 Returns: 596 Whether the given text can be identified. 597 """ 598 if identify is True or identify == "always": 599 return True 600 601 if identify == "safe": 602 return not self.case_sensitive(text) 603 604 return False 605 606 def quote_identifier(self, expression: E, identify: bool = True) -> E: 607 """ 608 Adds quotes to a given identifier. 609 610 Args: 611 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 612 identify: If set to `False`, the quotes will only be added if the identifier is deemed 613 "unsafe", with respect to its characters and this dialect's normalization strategy. 614 """ 615 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 616 name = expression.this 617 expression.set( 618 "quoted", 619 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 620 ) 621 622 return expression 623 624 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 625 if isinstance(path, exp.Literal): 626 path_text = path.name 627 if path.is_number: 628 path_text = f"[{path_text}]" 629 try: 630 return parse_json_path(path_text, self) 631 except ParseError as e: 632 logger.warning(f"Invalid JSON path syntax. {str(e)}") 633 634 return path 635 636 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 637 return self.parser(**opts).parse(self.tokenize(sql), sql) 638 639 def parse_into( 640 self, expression_type: exp.IntoType, sql: str, **opts 641 ) -> t.List[t.Optional[exp.Expression]]: 642 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 643 644 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 645 return self.generator(**opts).generate(expression, copy=copy) 646 647 def transpile(self, sql: str, **opts) -> t.List[str]: 648 return [ 649 self.generate(expression, copy=False, **opts) if expression else "" 650 for expression in self.parse(sql) 651 ] 652 653 def tokenize(self, sql: str) -> t.List[Token]: 654 return self.tokenizer.tokenize(sql) 655 656 @property 657 def tokenizer(self) -> Tokenizer: 658 if not hasattr(self, "_tokenizer"): 659 self._tokenizer = self.tokenizer_class(dialect=self) 660 return self._tokenizer 661 662 @property 663 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 664 if not hasattr(self, "_jsonpath_tokenizer"): 665 self._jsonpath_tokenizer = self.jsonpath_tokenizer_class(dialect=self) 666 return self._jsonpath_tokenizer 667 668 def parser(self, **opts) -> Parser: 669 return self.parser_class(dialect=self, **opts) 670 671 def generator(self, **opts) -> Generator: 672 return self.generator_class(dialect=self, **opts) 673 674 675DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 676 677 678def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 679 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 680 681 682def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 683 if expression.args.get("accuracy"): 684 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 685 return self.func("APPROX_COUNT_DISTINCT", expression.this) 686 687 688def if_sql( 689 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 690) -> t.Callable[[Generator, exp.If], str]: 691 def _if_sql(self: Generator, expression: exp.If) -> str: 692 return self.func( 693 name, 694 expression.this, 695 expression.args.get("true"), 696 expression.args.get("false") or false_value, 697 ) 698 699 return _if_sql 700 701 702def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 703 this = expression.this 704 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 705 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 706 707 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 708 709 710def inline_array_sql(self: Generator, expression: exp.Array) -> str: 711 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 712 713 714def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 715 elem = seq_get(expression.expressions, 0) 716 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 717 return self.func("ARRAY", elem) 718 return inline_array_sql(self, expression) 719 720 721def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 722 return self.like_sql( 723 exp.Like( 724 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 725 ) 726 ) 727 728 729def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 730 zone = self.sql(expression, "this") 731 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 732 733 734def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 735 if expression.args.get("recursive"): 736 self.unsupported("Recursive CTEs are unsupported") 737 expression.args["recursive"] = False 738 return self.with_sql(expression) 739 740 741def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 742 n = self.sql(expression, "this") 743 d = self.sql(expression, "expression") 744 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 745 746 747def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 748 self.unsupported("TABLESAMPLE unsupported") 749 return self.sql(expression.this) 750 751 752def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 753 self.unsupported("PIVOT unsupported") 754 return "" 755 756 757def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 758 return self.cast_sql(expression) 759 760 761def no_comment_column_constraint_sql( 762 self: Generator, expression: exp.CommentColumnConstraint 763) -> str: 764 self.unsupported("CommentColumnConstraint unsupported") 765 return "" 766 767 768def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 769 self.unsupported("MAP_FROM_ENTRIES unsupported") 770 return "" 771 772 773def str_position_sql( 774 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 775) -> str: 776 this = self.sql(expression, "this") 777 substr = self.sql(expression, "substr") 778 position = self.sql(expression, "position") 779 instance = expression.args.get("instance") if generate_instance else None 780 position_offset = "" 781 782 if position: 783 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 784 this = self.func("SUBSTR", this, position) 785 position_offset = f" + {position} - 1" 786 787 return self.func("STRPOS", this, substr, instance) + position_offset 788 789 790def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 791 return ( 792 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 793 ) 794 795 796def var_map_sql( 797 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 798) -> str: 799 keys = expression.args["keys"] 800 values = expression.args["values"] 801 802 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 803 self.unsupported("Cannot convert array columns into map.") 804 return self.func(map_func_name, keys, values) 805 806 args = [] 807 for key, value in zip(keys.expressions, values.expressions): 808 args.append(self.sql(key)) 809 args.append(self.sql(value)) 810 811 return self.func(map_func_name, *args) 812 813 814def build_formatted_time( 815 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 816) -> t.Callable[[t.List], E]: 817 """Helper used for time expressions. 818 819 Args: 820 exp_class: the expression class to instantiate. 821 dialect: target sql dialect. 822 default: the default format, True being time. 823 824 Returns: 825 A callable that can be used to return the appropriately formatted time expression. 826 """ 827 828 def _builder(args: t.List): 829 return exp_class( 830 this=seq_get(args, 0), 831 format=Dialect[dialect].format_time( 832 seq_get(args, 1) 833 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 834 ), 835 ) 836 837 return _builder 838 839 840def time_format( 841 dialect: DialectType = None, 842) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 843 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 844 """ 845 Returns the time format for a given expression, unless it's equivalent 846 to the default time format of the dialect of interest. 847 """ 848 time_format = self.format_time(expression) 849 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 850 851 return _time_format 852 853 854def build_date_delta( 855 exp_class: t.Type[E], 856 unit_mapping: t.Optional[t.Dict[str, str]] = None, 857 default_unit: t.Optional[str] = "DAY", 858) -> t.Callable[[t.List], E]: 859 def _builder(args: t.List) -> E: 860 unit_based = len(args) == 3 861 this = args[2] if unit_based else seq_get(args, 0) 862 unit = None 863 if unit_based or default_unit: 864 unit = args[0] if unit_based else exp.Literal.string(default_unit) 865 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 866 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 867 868 return _builder 869 870 871def build_date_delta_with_interval( 872 expression_class: t.Type[E], 873) -> t.Callable[[t.List], t.Optional[E]]: 874 def _builder(args: t.List) -> t.Optional[E]: 875 if len(args) < 2: 876 return None 877 878 interval = args[1] 879 880 if not isinstance(interval, exp.Interval): 881 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 882 883 expression = interval.this 884 if expression and expression.is_string: 885 expression = exp.Literal.number(expression.this) 886 887 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 888 889 return _builder 890 891 892def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 893 unit = seq_get(args, 0) 894 this = seq_get(args, 1) 895 896 if isinstance(this, exp.Cast) and this.is_type("date"): 897 return exp.DateTrunc(unit=unit, this=this) 898 return exp.TimestampTrunc(this=this, unit=unit) 899 900 901def date_add_interval_sql( 902 data_type: str, kind: str 903) -> t.Callable[[Generator, exp.Expression], str]: 904 def func(self: Generator, expression: exp.Expression) -> str: 905 this = self.sql(expression, "this") 906 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 907 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 908 909 return func 910 911 912def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 913 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 914 args = [unit_to_str(expression), expression.this] 915 if zone: 916 args.append(expression.args.get("zone")) 917 return self.func("DATE_TRUNC", *args) 918 919 return _timestamptrunc_sql 920 921 922def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 923 zone = expression.args.get("zone") 924 if not zone: 925 from sqlglot.optimizer.annotate_types import annotate_types 926 927 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 928 return self.sql(exp.cast(expression.this, target_type)) 929 if zone.name.lower() in TIMEZONES: 930 return self.sql( 931 exp.AtTimeZone( 932 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 933 zone=zone, 934 ) 935 ) 936 return self.func("TIMESTAMP", expression.this, zone) 937 938 939def no_time_sql(self: Generator, expression: exp.Time) -> str: 940 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 941 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 942 expr = exp.cast( 943 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 944 ) 945 return self.sql(expr) 946 947 948def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 949 this = expression.this 950 expr = expression.expression 951 952 if expr.name.lower() in TIMEZONES: 953 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 954 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 955 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 956 return self.sql(this) 957 958 this = exp.cast(this, exp.DataType.Type.DATE) 959 expr = exp.cast(expr, exp.DataType.Type.TIME) 960 961 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 962 963 964def locate_to_strposition(args: t.List) -> exp.Expression: 965 return exp.StrPosition( 966 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 967 ) 968 969 970def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 971 return self.func( 972 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 973 ) 974 975 976def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 977 return self.sql( 978 exp.Substring( 979 this=expression.this, start=exp.Literal.number(1), length=expression.expression 980 ) 981 ) 982 983 984def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 985 return self.sql( 986 exp.Substring( 987 this=expression.this, 988 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 989 ) 990 ) 991 992 993def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 994 return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)) 995 996 997def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 998 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 999 1000 1001# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1002def encode_decode_sql( 1003 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1004) -> str: 1005 charset = expression.args.get("charset") 1006 if charset and charset.name.lower() != "utf-8": 1007 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1008 1009 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1010 1011 1012def min_or_least(self: Generator, expression: exp.Min) -> str: 1013 name = "LEAST" if expression.expressions else "MIN" 1014 return rename_func(name)(self, expression) 1015 1016 1017def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1018 name = "GREATEST" if expression.expressions else "MAX" 1019 return rename_func(name)(self, expression) 1020 1021 1022def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1023 cond = expression.this 1024 1025 if isinstance(expression.this, exp.Distinct): 1026 cond = expression.this.expressions[0] 1027 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1028 1029 return self.func("sum", exp.func("if", cond, 1, 0)) 1030 1031 1032def trim_sql(self: Generator, expression: exp.Trim) -> str: 1033 target = self.sql(expression, "this") 1034 trim_type = self.sql(expression, "position") 1035 remove_chars = self.sql(expression, "expression") 1036 collation = self.sql(expression, "collation") 1037 1038 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1039 if not remove_chars and not collation: 1040 return self.trim_sql(expression) 1041 1042 trim_type = f"{trim_type} " if trim_type else "" 1043 remove_chars = f"{remove_chars} " if remove_chars else "" 1044 from_part = "FROM " if trim_type or remove_chars else "" 1045 collation = f" COLLATE {collation}" if collation else "" 1046 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1047 1048 1049def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1050 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1051 1052 1053def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1054 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1055 1056 1057def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1058 delim, *rest_args = expression.expressions 1059 return self.sql( 1060 reduce( 1061 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1062 rest_args, 1063 ) 1064 ) 1065 1066 1067def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1068 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1069 if bad_args: 1070 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1071 1072 return self.func( 1073 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1074 ) 1075 1076 1077def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1078 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1079 if bad_args: 1080 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1081 1082 return self.func( 1083 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1084 ) 1085 1086 1087def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1088 names = [] 1089 for agg in aggregations: 1090 if isinstance(agg, exp.Alias): 1091 names.append(agg.alias) 1092 else: 1093 """ 1094 This case corresponds to aggregations without aliases being used as suffixes 1095 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1096 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1097 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1098 """ 1099 agg_all_unquoted = agg.transform( 1100 lambda node: ( 1101 exp.Identifier(this=node.name, quoted=False) 1102 if isinstance(node, exp.Identifier) 1103 else node 1104 ) 1105 ) 1106 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1107 1108 return names 1109 1110 1111def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1112 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1113 1114 1115# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1116def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1117 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1118 1119 1120def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1121 return self.func("MAX", expression.this) 1122 1123 1124def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1125 a = self.sql(expression.left) 1126 b = self.sql(expression.right) 1127 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1128 1129 1130def is_parse_json(expression: exp.Expression) -> bool: 1131 return isinstance(expression, exp.ParseJSON) or ( 1132 isinstance(expression, exp.Cast) and expression.is_type("json") 1133 ) 1134 1135 1136def isnull_to_is_null(args: t.List) -> exp.Expression: 1137 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1138 1139 1140def generatedasidentitycolumnconstraint_sql( 1141 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1142) -> str: 1143 start = self.sql(expression, "start") or "1" 1144 increment = self.sql(expression, "increment") or "1" 1145 return f"IDENTITY({start}, {increment})" 1146 1147 1148def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1149 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1150 if expression.args.get("count"): 1151 self.unsupported(f"Only two arguments are supported in function {name}.") 1152 1153 return self.func(name, expression.this, expression.expression) 1154 1155 return _arg_max_or_min_sql 1156 1157 1158def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1159 this = expression.this.copy() 1160 1161 return_type = expression.return_type 1162 if return_type.is_type(exp.DataType.Type.DATE): 1163 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1164 # can truncate timestamp strings, because some dialects can't cast them to DATE 1165 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1166 1167 expression.this.replace(exp.cast(this, return_type)) 1168 return expression 1169 1170 1171def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1172 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1173 if cast and isinstance(expression, exp.TsOrDsAdd): 1174 expression = ts_or_ds_add_cast(expression) 1175 1176 return self.func( 1177 name, 1178 unit_to_var(expression), 1179 expression.expression, 1180 expression.this, 1181 ) 1182 1183 return _delta_sql 1184 1185 1186def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1187 unit = expression.args.get("unit") 1188 1189 if isinstance(unit, exp.Placeholder): 1190 return unit 1191 if unit: 1192 return exp.Literal.string(unit.name) 1193 return exp.Literal.string(default) if default else None 1194 1195 1196def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1197 unit = expression.args.get("unit") 1198 1199 if isinstance(unit, (exp.Var, exp.Placeholder)): 1200 return unit 1201 return exp.Var(this=default) if default else None 1202 1203 1204@t.overload 1205def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1206 pass 1207 1208 1209@t.overload 1210def map_date_part( 1211 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1212) -> t.Optional[exp.Expression]: 1213 pass 1214 1215 1216def map_date_part(part, dialect: DialectType = Dialect): 1217 mapped = ( 1218 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1219 ) 1220 return exp.var(mapped) if mapped else part 1221 1222 1223def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1224 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1225 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1226 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1227 1228 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1229 1230 1231def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1232 """Remove table refs from columns in when statements.""" 1233 alias = expression.this.args.get("alias") 1234 1235 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1236 return self.dialect.normalize_identifier(identifier).name if identifier else None 1237 1238 targets = {normalize(expression.this.this)} 1239 1240 if alias: 1241 targets.add(normalize(alias.this)) 1242 1243 for when in expression.expressions: 1244 when.transform( 1245 lambda node: ( 1246 exp.column(node.this) 1247 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1248 else node 1249 ), 1250 copy=False, 1251 ) 1252 1253 return self.merge_sql(expression) 1254 1255 1256def build_json_extract_path( 1257 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1258) -> t.Callable[[t.List], F]: 1259 def _builder(args: t.List) -> F: 1260 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1261 for arg in args[1:]: 1262 if not isinstance(arg, exp.Literal): 1263 # We use the fallback parser because we can't really transpile non-literals safely 1264 return expr_type.from_arg_list(args) 1265 1266 text = arg.name 1267 if is_int(text): 1268 index = int(text) 1269 segments.append( 1270 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1271 ) 1272 else: 1273 segments.append(exp.JSONPathKey(this=text)) 1274 1275 # This is done to avoid failing in the expression validator due to the arg count 1276 del args[2:] 1277 return expr_type( 1278 this=seq_get(args, 0), 1279 expression=exp.JSONPath(expressions=segments), 1280 only_json_types=arrow_req_json_type, 1281 ) 1282 1283 return _builder 1284 1285 1286def json_extract_segments( 1287 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1288) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1289 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1290 path = expression.expression 1291 if not isinstance(path, exp.JSONPath): 1292 return rename_func(name)(self, expression) 1293 1294 segments = [] 1295 for segment in path.expressions: 1296 path = self.sql(segment) 1297 if path: 1298 if isinstance(segment, exp.JSONPathPart) and ( 1299 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1300 ): 1301 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1302 1303 segments.append(path) 1304 1305 if op: 1306 return f" {op} ".join([self.sql(expression.this), *segments]) 1307 return self.func(name, expression.this, *segments) 1308 1309 return _json_extract_segments 1310 1311 1312def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1313 if isinstance(expression.this, exp.JSONPathWildcard): 1314 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1315 1316 return expression.name 1317 1318 1319def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1320 cond = expression.expression 1321 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1322 alias = cond.expressions[0] 1323 cond = cond.this 1324 elif isinstance(cond, exp.Predicate): 1325 alias = "_u" 1326 else: 1327 self.unsupported("Unsupported filter condition") 1328 return "" 1329 1330 unnest = exp.Unnest(expressions=[expression.this]) 1331 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1332 return self.sql(exp.Array(expressions=[filtered])) 1333 1334 1335def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1336 return self.func( 1337 "TO_NUMBER", 1338 expression.this, 1339 expression.args.get("format"), 1340 expression.args.get("nlsparam"), 1341 ) 1342 1343 1344def build_default_decimal_type( 1345 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1346) -> t.Callable[[exp.DataType], exp.DataType]: 1347 def _builder(dtype: exp.DataType) -> exp.DataType: 1348 if dtype.expressions or precision is None: 1349 return dtype 1350 1351 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1352 return exp.DataType.build(f"DECIMAL({params})") 1353 1354 return _builder 1355 1356 1357def build_timestamp_from_parts(args: t.List) -> exp.Func: 1358 if len(args) == 2: 1359 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1360 # so we parse this into Anonymous for now instead of introducing complexity 1361 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1362 1363 return exp.TimestampFromParts.from_arg_list(args) 1364 1365 1366def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1367 return self.func(f"SHA{expression.text('length') or '256'}", expression.this)
41class Dialects(str, Enum): 42 """Dialects supported by SQLGLot.""" 43 44 DIALECT = "" 45 46 ATHENA = "athena" 47 BIGQUERY = "bigquery" 48 CLICKHOUSE = "clickhouse" 49 DATABRICKS = "databricks" 50 DORIS = "doris" 51 DRILL = "drill" 52 DUCKDB = "duckdb" 53 HIVE = "hive" 54 MATERIALIZE = "materialize" 55 MYSQL = "mysql" 56 ORACLE = "oracle" 57 POSTGRES = "postgres" 58 PRESTO = "presto" 59 PRQL = "prql" 60 REDSHIFT = "redshift" 61 RISINGWAVE = "risingwave" 62 SNOWFLAKE = "snowflake" 63 SPARK = "spark" 64 SPARK2 = "spark2" 65 SQLITE = "sqlite" 66 STARROCKS = "starrocks" 67 TABLEAU = "tableau" 68 TERADATA = "teradata" 69 TRINO = "trino" 70 TSQL = "tsql"
Dialects supported by SQLGLot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
73class NormalizationStrategy(str, AutoName): 74 """Specifies the strategy according to which identifiers should be normalized.""" 75 76 LOWERCASE = auto() 77 """Unquoted identifiers are lowercased.""" 78 79 UPPERCASE = auto() 80 """Unquoted identifiers are uppercased.""" 81 82 CASE_SENSITIVE = auto() 83 """Always case-sensitive, regardless of quotes.""" 84 85 CASE_INSENSITIVE = auto() 86 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
205class Dialect(metaclass=_Dialect): 206 INDEX_OFFSET = 0 207 """The base index offset for arrays.""" 208 209 WEEK_OFFSET = 0 210 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 211 212 UNNEST_COLUMN_ONLY = False 213 """Whether `UNNEST` table aliases are treated as column aliases.""" 214 215 ALIAS_POST_TABLESAMPLE = False 216 """Whether the table alias comes after tablesample.""" 217 218 TABLESAMPLE_SIZE_IS_PERCENT = False 219 """Whether a size in the table sample clause represents percentage.""" 220 221 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 222 """Specifies the strategy according to which identifiers should be normalized.""" 223 224 IDENTIFIERS_CAN_START_WITH_DIGIT = False 225 """Whether an unquoted identifier can start with a digit.""" 226 227 DPIPE_IS_STRING_CONCAT = True 228 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 229 230 STRICT_STRING_CONCAT = False 231 """Whether `CONCAT`'s arguments must be strings.""" 232 233 SUPPORTS_USER_DEFINED_TYPES = True 234 """Whether user-defined data types are supported.""" 235 236 SUPPORTS_SEMI_ANTI_JOIN = True 237 """Whether `SEMI` or `ANTI` joins are supported.""" 238 239 SUPPORTS_COLUMN_JOIN_MARKS = False 240 """Whether the old-style outer join (+) syntax is supported.""" 241 242 COPY_PARAMS_ARE_CSV = True 243 """Separator of COPY statement parameters.""" 244 245 NORMALIZE_FUNCTIONS: bool | str = "upper" 246 """ 247 Determines how function names are going to be normalized. 248 Possible values: 249 "upper" or True: Convert names to uppercase. 250 "lower": Convert names to lowercase. 251 False: Disables function name normalization. 252 """ 253 254 LOG_BASE_FIRST: t.Optional[bool] = True 255 """ 256 Whether the base comes first in the `LOG` function. 257 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 258 """ 259 260 NULL_ORDERING = "nulls_are_small" 261 """ 262 Default `NULL` ordering method to use if not explicitly set. 263 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 264 """ 265 266 TYPED_DIVISION = False 267 """ 268 Whether the behavior of `a / b` depends on the types of `a` and `b`. 269 False means `a / b` is always float division. 270 True means `a / b` is integer division if both `a` and `b` are integers. 271 """ 272 273 SAFE_DIVISION = False 274 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 275 276 CONCAT_COALESCE = False 277 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 278 279 HEX_LOWERCASE = False 280 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 281 282 DATE_FORMAT = "'%Y-%m-%d'" 283 DATEINT_FORMAT = "'%Y%m%d'" 284 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 285 286 TIME_MAPPING: t.Dict[str, str] = {} 287 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 288 289 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 290 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 291 FORMAT_MAPPING: t.Dict[str, str] = {} 292 """ 293 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 294 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 295 """ 296 297 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 298 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 299 300 PSEUDOCOLUMNS: t.Set[str] = set() 301 """ 302 Columns that are auto-generated by the engine corresponding to this dialect. 303 For example, such columns may be excluded from `SELECT *` queries. 304 """ 305 306 PREFER_CTE_ALIAS_COLUMN = False 307 """ 308 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 309 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 310 any projection aliases in the subquery. 311 312 For example, 313 WITH y(c) AS ( 314 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 315 ) SELECT c FROM y; 316 317 will be rewritten as 318 319 WITH y(c) AS ( 320 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 321 ) SELECT c FROM y; 322 """ 323 324 COPY_PARAMS_ARE_CSV = True 325 """ 326 Whether COPY statement parameters are separated by comma or whitespace 327 """ 328 329 # --- Autofilled --- 330 331 tokenizer_class = Tokenizer 332 jsonpath_tokenizer_class = JSONPathTokenizer 333 parser_class = Parser 334 generator_class = Generator 335 336 # A trie of the time_mapping keys 337 TIME_TRIE: t.Dict = {} 338 FORMAT_TRIE: t.Dict = {} 339 340 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 341 INVERSE_TIME_TRIE: t.Dict = {} 342 343 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 344 345 # Delimiters for string literals and identifiers 346 QUOTE_START = "'" 347 QUOTE_END = "'" 348 IDENTIFIER_START = '"' 349 IDENTIFIER_END = '"' 350 351 # Delimiters for bit, hex, byte and unicode literals 352 BIT_START: t.Optional[str] = None 353 BIT_END: t.Optional[str] = None 354 HEX_START: t.Optional[str] = None 355 HEX_END: t.Optional[str] = None 356 BYTE_START: t.Optional[str] = None 357 BYTE_END: t.Optional[str] = None 358 UNICODE_START: t.Optional[str] = None 359 UNICODE_END: t.Optional[str] = None 360 361 DATE_PART_MAPPING = { 362 "Y": "YEAR", 363 "YY": "YEAR", 364 "YYY": "YEAR", 365 "YYYY": "YEAR", 366 "YR": "YEAR", 367 "YEARS": "YEAR", 368 "YRS": "YEAR", 369 "MM": "MONTH", 370 "MON": "MONTH", 371 "MONS": "MONTH", 372 "MONTHS": "MONTH", 373 "D": "DAY", 374 "DD": "DAY", 375 "DAYS": "DAY", 376 "DAYOFMONTH": "DAY", 377 "DAY OF WEEK": "DAYOFWEEK", 378 "WEEKDAY": "DAYOFWEEK", 379 "DOW": "DAYOFWEEK", 380 "DW": "DAYOFWEEK", 381 "WEEKDAY_ISO": "DAYOFWEEKISO", 382 "DOW_ISO": "DAYOFWEEKISO", 383 "DW_ISO": "DAYOFWEEKISO", 384 "DAY OF YEAR": "DAYOFYEAR", 385 "DOY": "DAYOFYEAR", 386 "DY": "DAYOFYEAR", 387 "W": "WEEK", 388 "WK": "WEEK", 389 "WEEKOFYEAR": "WEEK", 390 "WOY": "WEEK", 391 "WY": "WEEK", 392 "WEEK_ISO": "WEEKISO", 393 "WEEKOFYEARISO": "WEEKISO", 394 "WEEKOFYEAR_ISO": "WEEKISO", 395 "Q": "QUARTER", 396 "QTR": "QUARTER", 397 "QTRS": "QUARTER", 398 "QUARTERS": "QUARTER", 399 "H": "HOUR", 400 "HH": "HOUR", 401 "HR": "HOUR", 402 "HOURS": "HOUR", 403 "HRS": "HOUR", 404 "M": "MINUTE", 405 "MI": "MINUTE", 406 "MIN": "MINUTE", 407 "MINUTES": "MINUTE", 408 "MINS": "MINUTE", 409 "S": "SECOND", 410 "SEC": "SECOND", 411 "SECONDS": "SECOND", 412 "SECS": "SECOND", 413 "MS": "MILLISECOND", 414 "MSEC": "MILLISECOND", 415 "MSECS": "MILLISECOND", 416 "MSECOND": "MILLISECOND", 417 "MSECONDS": "MILLISECOND", 418 "MILLISEC": "MILLISECOND", 419 "MILLISECS": "MILLISECOND", 420 "MILLISECON": "MILLISECOND", 421 "MILLISECONDS": "MILLISECOND", 422 "US": "MICROSECOND", 423 "USEC": "MICROSECOND", 424 "USECS": "MICROSECOND", 425 "MICROSEC": "MICROSECOND", 426 "MICROSECS": "MICROSECOND", 427 "USECOND": "MICROSECOND", 428 "USECONDS": "MICROSECOND", 429 "MICROSECONDS": "MICROSECOND", 430 "NS": "NANOSECOND", 431 "NSEC": "NANOSECOND", 432 "NANOSEC": "NANOSECOND", 433 "NSECOND": "NANOSECOND", 434 "NSECONDS": "NANOSECOND", 435 "NANOSECS": "NANOSECOND", 436 "EPOCH_SECOND": "EPOCH", 437 "EPOCH_SECONDS": "EPOCH", 438 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 439 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 440 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 441 "TZH": "TIMEZONE_HOUR", 442 "TZM": "TIMEZONE_MINUTE", 443 "DEC": "DECADE", 444 "DECS": "DECADE", 445 "DECADES": "DECADE", 446 "MIL": "MILLENIUM", 447 "MILS": "MILLENIUM", 448 "MILLENIA": "MILLENIUM", 449 "C": "CENTURY", 450 "CENT": "CENTURY", 451 "CENTS": "CENTURY", 452 "CENTURIES": "CENTURY", 453 } 454 455 @classmethod 456 def get_or_raise(cls, dialect: DialectType) -> Dialect: 457 """ 458 Look up a dialect in the global dialect registry and return it if it exists. 459 460 Args: 461 dialect: The target dialect. If this is a string, it can be optionally followed by 462 additional key-value pairs that are separated by commas and are used to specify 463 dialect settings, such as whether the dialect's identifiers are case-sensitive. 464 465 Example: 466 >>> dialect = dialect_class = get_or_raise("duckdb") 467 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 468 469 Returns: 470 The corresponding Dialect instance. 471 """ 472 473 if not dialect: 474 return cls() 475 if isinstance(dialect, _Dialect): 476 return dialect() 477 if isinstance(dialect, Dialect): 478 return dialect 479 if isinstance(dialect, str): 480 try: 481 dialect_name, *kv_pairs = dialect.split(",") 482 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 483 except ValueError: 484 raise ValueError( 485 f"Invalid dialect format: '{dialect}'. " 486 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 487 ) 488 489 result = cls.get(dialect_name.strip()) 490 if not result: 491 from difflib import get_close_matches 492 493 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 494 if similar: 495 similar = f" Did you mean {similar}?" 496 497 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 498 499 return result(**kwargs) 500 501 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 502 503 @classmethod 504 def format_time( 505 cls, expression: t.Optional[str | exp.Expression] 506 ) -> t.Optional[exp.Expression]: 507 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 508 if isinstance(expression, str): 509 return exp.Literal.string( 510 # the time formats are quoted 511 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 512 ) 513 514 if expression and expression.is_string: 515 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 516 517 return expression 518 519 def __init__(self, **kwargs) -> None: 520 normalization_strategy = kwargs.pop("normalization_strategy", None) 521 522 if normalization_strategy is None: 523 self.normalization_strategy = self.NORMALIZATION_STRATEGY 524 else: 525 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 526 527 self.settings = kwargs 528 529 def __eq__(self, other: t.Any) -> bool: 530 # Does not currently take dialect state into account 531 return type(self) == other 532 533 def __hash__(self) -> int: 534 # Does not currently take dialect state into account 535 return hash(type(self)) 536 537 def normalize_identifier(self, expression: E) -> E: 538 """ 539 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 540 541 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 542 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 543 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 544 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 545 546 There are also dialects like Spark, which are case-insensitive even when quotes are 547 present, and dialects like MySQL, whose resolution rules match those employed by the 548 underlying operating system, for example they may always be case-sensitive in Linux. 549 550 Finally, the normalization behavior of some engines can even be controlled through flags, 551 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 552 553 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 554 that it can analyze queries in the optimizer and successfully capture their semantics. 555 """ 556 if ( 557 isinstance(expression, exp.Identifier) 558 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 559 and ( 560 not expression.quoted 561 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 562 ) 563 ): 564 expression.set( 565 "this", 566 ( 567 expression.this.upper() 568 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 569 else expression.this.lower() 570 ), 571 ) 572 573 return expression 574 575 def case_sensitive(self, text: str) -> bool: 576 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 577 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 578 return False 579 580 unsafe = ( 581 str.islower 582 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 583 else str.isupper 584 ) 585 return any(unsafe(char) for char in text) 586 587 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 588 """Checks if text can be identified given an identify option. 589 590 Args: 591 text: The text to check. 592 identify: 593 `"always"` or `True`: Always returns `True`. 594 `"safe"`: Only returns `True` if the identifier is case-insensitive. 595 596 Returns: 597 Whether the given text can be identified. 598 """ 599 if identify is True or identify == "always": 600 return True 601 602 if identify == "safe": 603 return not self.case_sensitive(text) 604 605 return False 606 607 def quote_identifier(self, expression: E, identify: bool = True) -> E: 608 """ 609 Adds quotes to a given identifier. 610 611 Args: 612 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 613 identify: If set to `False`, the quotes will only be added if the identifier is deemed 614 "unsafe", with respect to its characters and this dialect's normalization strategy. 615 """ 616 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 617 name = expression.this 618 expression.set( 619 "quoted", 620 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 621 ) 622 623 return expression 624 625 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 626 if isinstance(path, exp.Literal): 627 path_text = path.name 628 if path.is_number: 629 path_text = f"[{path_text}]" 630 try: 631 return parse_json_path(path_text, self) 632 except ParseError as e: 633 logger.warning(f"Invalid JSON path syntax. {str(e)}") 634 635 return path 636 637 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 638 return self.parser(**opts).parse(self.tokenize(sql), sql) 639 640 def parse_into( 641 self, expression_type: exp.IntoType, sql: str, **opts 642 ) -> t.List[t.Optional[exp.Expression]]: 643 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 644 645 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 646 return self.generator(**opts).generate(expression, copy=copy) 647 648 def transpile(self, sql: str, **opts) -> t.List[str]: 649 return [ 650 self.generate(expression, copy=False, **opts) if expression else "" 651 for expression in self.parse(sql) 652 ] 653 654 def tokenize(self, sql: str) -> t.List[Token]: 655 return self.tokenizer.tokenize(sql) 656 657 @property 658 def tokenizer(self) -> Tokenizer: 659 if not hasattr(self, "_tokenizer"): 660 self._tokenizer = self.tokenizer_class(dialect=self) 661 return self._tokenizer 662 663 @property 664 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 665 if not hasattr(self, "_jsonpath_tokenizer"): 666 self._jsonpath_tokenizer = self.jsonpath_tokenizer_class(dialect=self) 667 return self._jsonpath_tokenizer 668 669 def parser(self, **opts) -> Parser: 670 return self.parser_class(dialect=self, **opts) 671 672 def generator(self, **opts) -> Generator: 673 return self.generator_class(dialect=self, **opts)
519 def __init__(self, **kwargs) -> None: 520 normalization_strategy = kwargs.pop("normalization_strategy", None) 521 522 if normalization_strategy is None: 523 self.normalization_strategy = self.NORMALIZATION_STRATEGY 524 else: 525 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 526 527 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of a / b
depends on the types of a
and b
.
False means a / b
is always float division.
True means a / b
is integer division if both a
and b
are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (\n
) to its unescaped version (
).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
455 @classmethod 456 def get_or_raise(cls, dialect: DialectType) -> Dialect: 457 """ 458 Look up a dialect in the global dialect registry and return it if it exists. 459 460 Args: 461 dialect: The target dialect. If this is a string, it can be optionally followed by 462 additional key-value pairs that are separated by commas and are used to specify 463 dialect settings, such as whether the dialect's identifiers are case-sensitive. 464 465 Example: 466 >>> dialect = dialect_class = get_or_raise("duckdb") 467 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 468 469 Returns: 470 The corresponding Dialect instance. 471 """ 472 473 if not dialect: 474 return cls() 475 if isinstance(dialect, _Dialect): 476 return dialect() 477 if isinstance(dialect, Dialect): 478 return dialect 479 if isinstance(dialect, str): 480 try: 481 dialect_name, *kv_pairs = dialect.split(",") 482 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 483 except ValueError: 484 raise ValueError( 485 f"Invalid dialect format: '{dialect}'. " 486 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 487 ) 488 489 result = cls.get(dialect_name.strip()) 490 if not result: 491 from difflib import get_close_matches 492 493 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 494 if similar: 495 similar = f" Did you mean {similar}?" 496 497 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 498 499 return result(**kwargs) 500 501 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
503 @classmethod 504 def format_time( 505 cls, expression: t.Optional[str | exp.Expression] 506 ) -> t.Optional[exp.Expression]: 507 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 508 if isinstance(expression, str): 509 return exp.Literal.string( 510 # the time formats are quoted 511 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 512 ) 513 514 if expression and expression.is_string: 515 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 516 517 return expression
Converts a time format in this dialect to its equivalent Python strftime
format.
537 def normalize_identifier(self, expression: E) -> E: 538 """ 539 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 540 541 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 542 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 543 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 544 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 545 546 There are also dialects like Spark, which are case-insensitive even when quotes are 547 present, and dialects like MySQL, whose resolution rules match those employed by the 548 underlying operating system, for example they may always be case-sensitive in Linux. 549 550 Finally, the normalization behavior of some engines can even be controlled through flags, 551 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 552 553 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 554 that it can analyze queries in the optimizer and successfully capture their semantics. 555 """ 556 if ( 557 isinstance(expression, exp.Identifier) 558 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 559 and ( 560 not expression.quoted 561 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 562 ) 563 ): 564 expression.set( 565 "this", 566 ( 567 expression.this.upper() 568 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 569 else expression.this.lower() 570 ), 571 ) 572 573 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO
would be resolved as foo
in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO
. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
575 def case_sensitive(self, text: str) -> bool: 576 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 577 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 578 return False 579 580 unsafe = ( 581 str.islower 582 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 583 else str.isupper 584 ) 585 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
587 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 588 """Checks if text can be identified given an identify option. 589 590 Args: 591 text: The text to check. 592 identify: 593 `"always"` or `True`: Always returns `True`. 594 `"safe"`: Only returns `True` if the identifier is case-insensitive. 595 596 Returns: 597 Whether the given text can be identified. 598 """ 599 if identify is True or identify == "always": 600 return True 601 602 if identify == "safe": 603 return not self.case_sensitive(text) 604 605 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
"always"
orTrue
: Always returnsTrue
."safe"
: Only returnsTrue
if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
607 def quote_identifier(self, expression: E, identify: bool = True) -> E: 608 """ 609 Adds quotes to a given identifier. 610 611 Args: 612 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 613 identify: If set to `False`, the quotes will only be added if the identifier is deemed 614 "unsafe", with respect to its characters and this dialect's normalization strategy. 615 """ 616 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 617 name = expression.this 618 expression.set( 619 "quoted", 620 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 621 ) 622 623 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an
Identifier
, this method is a no-op. - identify: If set to
False
, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
625 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 626 if isinstance(path, exp.Literal): 627 path_text = path.name 628 if path.is_number: 629 path_text = f"[{path_text}]" 630 try: 631 return parse_json_path(path_text, self) 632 except ParseError as e: 633 logger.warning(f"Invalid JSON path syntax. {str(e)}") 634 635 return path
689def if_sql( 690 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 691) -> t.Callable[[Generator, exp.If], str]: 692 def _if_sql(self: Generator, expression: exp.If) -> str: 693 return self.func( 694 name, 695 expression.this, 696 expression.args.get("true"), 697 expression.args.get("false") or false_value, 698 ) 699 700 return _if_sql
703def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 704 this = expression.this 705 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 706 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 707 708 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
774def str_position_sql( 775 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 776) -> str: 777 this = self.sql(expression, "this") 778 substr = self.sql(expression, "substr") 779 position = self.sql(expression, "position") 780 instance = expression.args.get("instance") if generate_instance else None 781 position_offset = "" 782 783 if position: 784 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 785 this = self.func("SUBSTR", this, position) 786 position_offset = f" + {position} - 1" 787 788 return self.func("STRPOS", this, substr, instance) + position_offset
797def var_map_sql( 798 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 799) -> str: 800 keys = expression.args["keys"] 801 values = expression.args["values"] 802 803 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 804 self.unsupported("Cannot convert array columns into map.") 805 return self.func(map_func_name, keys, values) 806 807 args = [] 808 for key, value in zip(keys.expressions, values.expressions): 809 args.append(self.sql(key)) 810 args.append(self.sql(value)) 811 812 return self.func(map_func_name, *args)
815def build_formatted_time( 816 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 817) -> t.Callable[[t.List], E]: 818 """Helper used for time expressions. 819 820 Args: 821 exp_class: the expression class to instantiate. 822 dialect: target sql dialect. 823 default: the default format, True being time. 824 825 Returns: 826 A callable that can be used to return the appropriately formatted time expression. 827 """ 828 829 def _builder(args: t.List): 830 return exp_class( 831 this=seq_get(args, 0), 832 format=Dialect[dialect].format_time( 833 seq_get(args, 1) 834 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 835 ), 836 ) 837 838 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
841def time_format( 842 dialect: DialectType = None, 843) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 844 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 845 """ 846 Returns the time format for a given expression, unless it's equivalent 847 to the default time format of the dialect of interest. 848 """ 849 time_format = self.format_time(expression) 850 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 851 852 return _time_format
855def build_date_delta( 856 exp_class: t.Type[E], 857 unit_mapping: t.Optional[t.Dict[str, str]] = None, 858 default_unit: t.Optional[str] = "DAY", 859) -> t.Callable[[t.List], E]: 860 def _builder(args: t.List) -> E: 861 unit_based = len(args) == 3 862 this = args[2] if unit_based else seq_get(args, 0) 863 unit = None 864 if unit_based or default_unit: 865 unit = args[0] if unit_based else exp.Literal.string(default_unit) 866 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 867 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 868 869 return _builder
872def build_date_delta_with_interval( 873 expression_class: t.Type[E], 874) -> t.Callable[[t.List], t.Optional[E]]: 875 def _builder(args: t.List) -> t.Optional[E]: 876 if len(args) < 2: 877 return None 878 879 interval = args[1] 880 881 if not isinstance(interval, exp.Interval): 882 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 883 884 expression = interval.this 885 if expression and expression.is_string: 886 expression = exp.Literal.number(expression.this) 887 888 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 889 890 return _builder
902def date_add_interval_sql( 903 data_type: str, kind: str 904) -> t.Callable[[Generator, exp.Expression], str]: 905 def func(self: Generator, expression: exp.Expression) -> str: 906 this = self.sql(expression, "this") 907 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 908 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 909 910 return func
913def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 914 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 915 args = [unit_to_str(expression), expression.this] 916 if zone: 917 args.append(expression.args.get("zone")) 918 return self.func("DATE_TRUNC", *args) 919 920 return _timestamptrunc_sql
923def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 924 zone = expression.args.get("zone") 925 if not zone: 926 from sqlglot.optimizer.annotate_types import annotate_types 927 928 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 929 return self.sql(exp.cast(expression.this, target_type)) 930 if zone.name.lower() in TIMEZONES: 931 return self.sql( 932 exp.AtTimeZone( 933 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 934 zone=zone, 935 ) 936 ) 937 return self.func("TIMESTAMP", expression.this, zone)
940def no_time_sql(self: Generator, expression: exp.Time) -> str: 941 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 942 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 943 expr = exp.cast( 944 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 945 ) 946 return self.sql(expr)
949def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 950 this = expression.this 951 expr = expression.expression 952 953 if expr.name.lower() in TIMEZONES: 954 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 955 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 956 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 957 return self.sql(this) 958 959 this = exp.cast(this, exp.DataType.Type.DATE) 960 expr = exp.cast(expr, exp.DataType.Type.TIME) 961 962 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1003def encode_decode_sql( 1004 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1005) -> str: 1006 charset = expression.args.get("charset") 1007 if charset and charset.name.lower() != "utf-8": 1008 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1009 1010 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1023def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1024 cond = expression.this 1025 1026 if isinstance(expression.this, exp.Distinct): 1027 cond = expression.this.expressions[0] 1028 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1029 1030 return self.func("sum", exp.func("if", cond, 1, 0))
1033def trim_sql(self: Generator, expression: exp.Trim) -> str: 1034 target = self.sql(expression, "this") 1035 trim_type = self.sql(expression, "position") 1036 remove_chars = self.sql(expression, "expression") 1037 collation = self.sql(expression, "collation") 1038 1039 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1040 if not remove_chars and not collation: 1041 return self.trim_sql(expression) 1042 1043 trim_type = f"{trim_type} " if trim_type else "" 1044 remove_chars = f"{remove_chars} " if remove_chars else "" 1045 from_part = "FROM " if trim_type or remove_chars else "" 1046 collation = f" COLLATE {collation}" if collation else "" 1047 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1068def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1069 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1070 if bad_args: 1071 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1072 1073 return self.func( 1074 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1075 )
1078def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1079 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1080 if bad_args: 1081 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1082 1083 return self.func( 1084 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1085 )
1088def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1089 names = [] 1090 for agg in aggregations: 1091 if isinstance(agg, exp.Alias): 1092 names.append(agg.alias) 1093 else: 1094 """ 1095 This case corresponds to aggregations without aliases being used as suffixes 1096 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1097 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1098 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1099 """ 1100 agg_all_unquoted = agg.transform( 1101 lambda node: ( 1102 exp.Identifier(this=node.name, quoted=False) 1103 if isinstance(node, exp.Identifier) 1104 else node 1105 ) 1106 ) 1107 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1108 1109 return names
1149def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1150 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1151 if expression.args.get("count"): 1152 self.unsupported(f"Only two arguments are supported in function {name}.") 1153 1154 return self.func(name, expression.this, expression.expression) 1155 1156 return _arg_max_or_min_sql
1159def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1160 this = expression.this.copy() 1161 1162 return_type = expression.return_type 1163 if return_type.is_type(exp.DataType.Type.DATE): 1164 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1165 # can truncate timestamp strings, because some dialects can't cast them to DATE 1166 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1167 1168 expression.this.replace(exp.cast(this, return_type)) 1169 return expression
1172def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1173 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1174 if cast and isinstance(expression, exp.TsOrDsAdd): 1175 expression = ts_or_ds_add_cast(expression) 1176 1177 return self.func( 1178 name, 1179 unit_to_var(expression), 1180 expression.expression, 1181 expression.this, 1182 ) 1183 1184 return _delta_sql
1187def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1188 unit = expression.args.get("unit") 1189 1190 if isinstance(unit, exp.Placeholder): 1191 return unit 1192 if unit: 1193 return exp.Literal.string(unit.name) 1194 return exp.Literal.string(default) if default else None
1224def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1225 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1226 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1227 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1228 1229 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1232def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1233 """Remove table refs from columns in when statements.""" 1234 alias = expression.this.args.get("alias") 1235 1236 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1237 return self.dialect.normalize_identifier(identifier).name if identifier else None 1238 1239 targets = {normalize(expression.this.this)} 1240 1241 if alias: 1242 targets.add(normalize(alias.this)) 1243 1244 for when in expression.expressions: 1245 when.transform( 1246 lambda node: ( 1247 exp.column(node.this) 1248 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1249 else node 1250 ), 1251 copy=False, 1252 ) 1253 1254 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1257def build_json_extract_path( 1258 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1259) -> t.Callable[[t.List], F]: 1260 def _builder(args: t.List) -> F: 1261 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1262 for arg in args[1:]: 1263 if not isinstance(arg, exp.Literal): 1264 # We use the fallback parser because we can't really transpile non-literals safely 1265 return expr_type.from_arg_list(args) 1266 1267 text = arg.name 1268 if is_int(text): 1269 index = int(text) 1270 segments.append( 1271 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1272 ) 1273 else: 1274 segments.append(exp.JSONPathKey(this=text)) 1275 1276 # This is done to avoid failing in the expression validator due to the arg count 1277 del args[2:] 1278 return expr_type( 1279 this=seq_get(args, 0), 1280 expression=exp.JSONPath(expressions=segments), 1281 only_json_types=arrow_req_json_type, 1282 ) 1283 1284 return _builder
1287def json_extract_segments( 1288 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1289) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1290 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1291 path = expression.expression 1292 if not isinstance(path, exp.JSONPath): 1293 return rename_func(name)(self, expression) 1294 1295 segments = [] 1296 for segment in path.expressions: 1297 path = self.sql(segment) 1298 if path: 1299 if isinstance(segment, exp.JSONPathPart) and ( 1300 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1301 ): 1302 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1303 1304 segments.append(path) 1305 1306 if op: 1307 return f" {op} ".join([self.sql(expression.this), *segments]) 1308 return self.func(name, expression.this, *segments) 1309 1310 return _json_extract_segments
1320def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1321 cond = expression.expression 1322 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1323 alias = cond.expressions[0] 1324 cond = cond.this 1325 elif isinstance(cond, exp.Predicate): 1326 alias = "_u" 1327 else: 1328 self.unsupported("Unsupported filter condition") 1329 return "" 1330 1331 unnest = exp.Unnest(expressions=[expression.this]) 1332 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1333 return self.sql(exp.Array(expressions=[filtered]))
1345def build_default_decimal_type( 1346 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1347) -> t.Callable[[exp.DataType], exp.DataType]: 1348 def _builder(dtype: exp.DataType) -> exp.DataType: 1349 if dtype.expressions or precision is None: 1350 return dtype 1351 1352 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1353 return exp.DataType.build(f"DECIMAL({params})") 1354 1355 return _builder
1358def build_timestamp_from_parts(args: t.List) -> exp.Func: 1359 if len(args) == 2: 1360 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1361 # so we parse this into Anonymous for now instead of introducing complexity 1362 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1363 1364 return exp.TimestampFromParts.from_arg_list(args)