sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
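
# Illustrative behavior of the enum above (a sketch): `AutoName` assigns each
# member its own name as its value, and the `str` mixin makes members compare
# equal to plain strings, which is what lets `Dialect.__init__` below accept
# either form.
#
#     >>> NormalizationStrategy("LOWERCASE") is NormalizationStrategy.LOWERCASE
#     True
#     >>> NormalizationStrategy.UPPERCASE == "UPPERCASE"
#     True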

class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                "\\a": "\a",
                "\\b": "\b",
                "\\f": "\f",
                "\\n": "\n",
                "\\r": "\r",
                "\\t": "\t",
                "\\v": "\v",
                "\\\\": "\\",
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
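
# Illustrative sketch: merely defining a subclass of `Dialect` (declared below)
# registers it, because `_Dialect.__new__` adds every new class to the shared
# `classes` registry under its lowercased name. The name "Custom" here is
# hypothetical.
#
#     >>> class Custom(Dialect):
#     ...     pass
#     >>> Dialect.get("custom") is Custom
#     True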

class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
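
# A short usage sketch of the class above; the SQL is illustrative and the
# expected outputs assume the stock "duckdb" and "snowflake" dialects.
#
#     >>> Dialect.get_or_raise("duckdb").transpile("SELECT 1 AS x")
#     ['SELECT 1 AS x']
#     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
#     'FOO'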

DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
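
# Illustrative use of `rename_func`: it re-emits an expression under a new
# function name, forwarding all of its arguments (the POWER/Pow pairing here is
# just an example).
#
#     >>> gen = Dialect.get_or_raise(None).generator()
#     >>> rename_func("POWER")(gen, exp.Pow(this=exp.column("x"), expression=exp.Literal.number(2)))
#     'POWER(x, 2)'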

def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
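
# The `no_*` fallbacks above are typically wired into a dialect's generator via
# its TRANSFORMS mapping, along the lines of this sketch (the dialect itself is
# hypothetical):
#
#     class MyDialect(Dialect):
#         class Generator(Generator):
#             TRANSFORMS = {
#                 **Generator.TRANSFORMS,
#                 exp.ILike: no_ilike_sql,
#                 exp.TryCast: no_trycast_sql,
#             }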

def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
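
# Illustrative output of `str_position_sql`: a third `position` argument is
# rewritten over SUBSTR and then shifted back into the coordinates of the
# original string, e.g.
#
#     STR_POSITION(haystack, needle, 5)
#         -> STRPOS(SUBSTR(haystack, 5), needle) + 5 - 1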

def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
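
# Dialects typically register these builders in their parser's FUNCTIONS
# mapping, roughly as in this sketch (the exact entries vary per dialect):
#
#     class MyDialect(Dialect):
#         class Parser(Parser):
#             FUNCTIONS = {
#                 **Parser.FUNCTIONS,
#                 "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "hive"),
#             }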

def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
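
# Illustrative behavior of `build_date_delta`: it accepts both common argument
# orders for exp.DateAdd and friends, e.g.
#
#     DATEADD(day, 5, x)  ->  exp.DateAdd(this=x, expression=5, unit=day)
#     DATE_ADD(x, 5)      ->  exp.DateAdd(this=x, expression=5, unit='DAY')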

def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )
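
# Illustrative behavior of `min_or_least` (and, symmetrically,
# `max_or_greatest`): a single argument stays an aggregate, while additional
# arguments turn it into a variadic scalar function.
#
#     >>> gen = Dialect.get_or_raise(None).generator()
#     >>> min_or_least(gen, exp.Min(this=exp.column("x")))
#     'MIN(x)'
#     >>> min_or_least(gen, exp.Min(this=exp.column("x"), expressions=[exp.column("y")]))
#     'LEAST(x, y)'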

def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
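
# Illustrative expansion performed by `bool_xor_sql` for engines that lack a
# boolean XOR operator:
#
#     >>> gen = Dialect.get_or_raise(None).generator()
#     >>> bool_xor_sql(gen, exp.Xor(this=exp.column("a"), expression=exp.column("b")))
#     '(a AND (NOT b)) OR ((NOT a) AND b)'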

def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
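
# Illustrative behavior of `build_json_extract_path`: literal segments are
# folded into a JSONPath, while any non-literal segment falls back to the
# regular function parser. With zero-based indexing, arguments equivalent to
# JSON_EXTRACT_PATH(x, 'a', '0') produce the path $.a[0].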

def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )
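
# Illustrative output of `json_extract_segments` when an operator is supplied:
# with name="JSON_EXTRACT" and op="->", and a dialect that renders path keys by
# name alone (see `json_path_key_only_name` above), a parsed path like $.a.b
# renders as a chain of arrow operators rather than a function call, e.g.
#
#     x -> 'a' -> 'b'
#
# where the quoting follows the dialect's QUOTE_START/QUOTE_END delimiters.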
30class Dialects(str, Enum): 31 """Dialects supported by SQLGLot.""" 32 33 DIALECT = "" 34 35 ATHENA = "athena" 36 BIGQUERY = "bigquery" 37 CLICKHOUSE = "clickhouse" 38 DATABRICKS = "databricks" 39 DORIS = "doris" 40 DRILL = "drill" 41 DUCKDB = "duckdb" 42 HIVE = "hive" 43 MYSQL = "mysql" 44 ORACLE = "oracle" 45 POSTGRES = "postgres" 46 PRESTO = "presto" 47 PRQL = "prql" 48 REDSHIFT = "redshift" 49 SNOWFLAKE = "snowflake" 50 SPARK = "spark" 51 SPARK2 = "spark2" 52 SQLITE = "sqlite" 53 STARROCKS = "starrocks" 54 TABLEAU = "tableau" 55 TERADATA = "teradata" 56 TRINO = "trino" 57 TSQL = "tsql"
Dialects supported by SQLGLot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
60class NormalizationStrategy(str, AutoName): 61 """Specifies the strategy according to which identifiers should be normalized.""" 62 63 LOWERCASE = auto() 64 """Unquoted identifiers are lowercased.""" 65 66 UPPERCASE = auto() 67 """Unquoted identifiers are uppercased.""" 68 69 CASE_SENSITIVE = auto() 70 """Always case-sensitive, regardless of quotes.""" 71 72 CASE_INSENSITIVE = auto() 73 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
181class Dialect(metaclass=_Dialect): 182 INDEX_OFFSET = 0 183 """The base index offset for arrays.""" 184 185 WEEK_OFFSET = 0 186 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 187 188 UNNEST_COLUMN_ONLY = False 189 """Whether `UNNEST` table aliases are treated as column aliases.""" 190 191 ALIAS_POST_TABLESAMPLE = False 192 """Whether the table alias comes after tablesample.""" 193 194 TABLESAMPLE_SIZE_IS_PERCENT = False 195 """Whether a size in the table sample clause represents percentage.""" 196 197 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 198 """Specifies the strategy according to which identifiers should be normalized.""" 199 200 IDENTIFIERS_CAN_START_WITH_DIGIT = False 201 """Whether an unquoted identifier can start with a digit.""" 202 203 DPIPE_IS_STRING_CONCAT = True 204 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 205 206 STRICT_STRING_CONCAT = False 207 """Whether `CONCAT`'s arguments must be strings.""" 208 209 SUPPORTS_USER_DEFINED_TYPES = True 210 """Whether user-defined data types are supported.""" 211 212 SUPPORTS_SEMI_ANTI_JOIN = True 213 """Whether `SEMI` or `ANTI` joins are supported.""" 214 215 NORMALIZE_FUNCTIONS: bool | str = "upper" 216 """ 217 Determines how function names are going to be normalized. 218 Possible values: 219 "upper" or True: Convert names to uppercase. 220 "lower": Convert names to lowercase. 221 False: Disables function name normalization. 222 """ 223 224 LOG_BASE_FIRST: t.Optional[bool] = True 225 """ 226 Whether the base comes first in the `LOG` function. 227 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 228 """ 229 230 NULL_ORDERING = "nulls_are_small" 231 """ 232 Default `NULL` ordering method to use if not explicitly set. 233 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 234 """ 235 236 TYPED_DIVISION = False 237 """ 238 Whether the behavior of `a / b` depends on the types of `a` and `b`. 239 False means `a / b` is always float division. 240 True means `a / b` is integer division if both `a` and `b` are integers. 241 """ 242 243 SAFE_DIVISION = False 244 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 245 246 CONCAT_COALESCE = False 247 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 248 249 DATE_FORMAT = "'%Y-%m-%d'" 250 DATEINT_FORMAT = "'%Y%m%d'" 251 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 252 253 TIME_MAPPING: t.Dict[str, str] = {} 254 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 255 256 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 257 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 258 FORMAT_MAPPING: t.Dict[str, str] = {} 259 """ 260 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 261 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 262 """ 263 264 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 265 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 266 267 PSEUDOCOLUMNS: t.Set[str] = set() 268 """ 269 Columns that are auto-generated by the engine corresponding to this dialect. 
270 For example, such columns may be excluded from `SELECT *` queries. 271 """ 272 273 PREFER_CTE_ALIAS_COLUMN = False 274 """ 275 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 276 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 277 any projection aliases in the subquery. 278 279 For example, 280 WITH y(c) AS ( 281 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 282 ) SELECT c FROM y; 283 284 will be rewritten as 285 286 WITH y(c) AS ( 287 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 288 ) SELECT c FROM y; 289 """ 290 291 # --- Autofilled --- 292 293 tokenizer_class = Tokenizer 294 parser_class = Parser 295 generator_class = Generator 296 297 # A trie of the time_mapping keys 298 TIME_TRIE: t.Dict = {} 299 FORMAT_TRIE: t.Dict = {} 300 301 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 302 INVERSE_TIME_TRIE: t.Dict = {} 303 304 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 305 306 # Delimiters for string literals and identifiers 307 QUOTE_START = "'" 308 QUOTE_END = "'" 309 IDENTIFIER_START = '"' 310 IDENTIFIER_END = '"' 311 312 # Delimiters for bit, hex, byte and unicode literals 313 BIT_START: t.Optional[str] = None 314 BIT_END: t.Optional[str] = None 315 HEX_START: t.Optional[str] = None 316 HEX_END: t.Optional[str] = None 317 BYTE_START: t.Optional[str] = None 318 BYTE_END: t.Optional[str] = None 319 UNICODE_START: t.Optional[str] = None 320 UNICODE_END: t.Optional[str] = None 321 322 @classmethod 323 def get_or_raise(cls, dialect: DialectType) -> Dialect: 324 """ 325 Look up a dialect in the global dialect registry and return it if it exists. 326 327 Args: 328 dialect: The target dialect. If this is a string, it can be optionally followed by 329 additional key-value pairs that are separated by commas and are used to specify 330 dialect settings, such as whether the dialect's identifiers are case-sensitive. 331 332 Example: 333 >>> dialect = dialect_class = get_or_raise("duckdb") 334 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 335 336 Returns: 337 The corresponding Dialect instance. 338 """ 339 340 if not dialect: 341 return cls() 342 if isinstance(dialect, _Dialect): 343 return dialect() 344 if isinstance(dialect, Dialect): 345 return dialect 346 if isinstance(dialect, str): 347 try: 348 dialect_name, *kv_pairs = dialect.split(",") 349 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 350 except ValueError: 351 raise ValueError( 352 f"Invalid dialect format: '{dialect}'. " 353 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 354 ) 355 356 result = cls.get(dialect_name.strip()) 357 if not result: 358 from difflib import get_close_matches 359 360 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 361 if similar: 362 similar = f" Did you mean {similar}?" 
363 364 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 365 366 return result(**kwargs) 367 368 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 369 370 @classmethod 371 def format_time( 372 cls, expression: t.Optional[str | exp.Expression] 373 ) -> t.Optional[exp.Expression]: 374 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 375 if isinstance(expression, str): 376 return exp.Literal.string( 377 # the time formats are quoted 378 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 379 ) 380 381 if expression and expression.is_string: 382 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 383 384 return expression 385 386 def __init__(self, **kwargs) -> None: 387 normalization_strategy = kwargs.get("normalization_strategy") 388 389 if normalization_strategy is None: 390 self.normalization_strategy = self.NORMALIZATION_STRATEGY 391 else: 392 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 393 394 def __eq__(self, other: t.Any) -> bool: 395 # Does not currently take dialect state into account 396 return type(self) == other 397 398 def __hash__(self) -> int: 399 # Does not currently take dialect state into account 400 return hash(type(self)) 401 402 def normalize_identifier(self, expression: E) -> E: 403 """ 404 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 405 406 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 407 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 408 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 409 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 410 411 There are also dialects like Spark, which are case-insensitive even when quotes are 412 present, and dialects like MySQL, whose resolution rules match those employed by the 413 underlying operating system, for example they may always be case-sensitive in Linux. 414 415 Finally, the normalization behavior of some engines can even be controlled through flags, 416 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 417 418 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 419 that it can analyze queries in the optimizer and successfully capture their semantics. 420 """ 421 if ( 422 isinstance(expression, exp.Identifier) 423 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 424 and ( 425 not expression.quoted 426 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 427 ) 428 ): 429 expression.set( 430 "this", 431 ( 432 expression.this.upper() 433 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 434 else expression.this.lower() 435 ), 436 ) 437 438 return expression 439 440 def case_sensitive(self, text: str) -> bool: 441 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 442 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 443 return False 444 445 unsafe = ( 446 str.islower 447 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 448 else str.isupper 449 ) 450 return any(unsafe(char) for char in text) 451 452 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 453 """Checks if text can be identified given an identify option. 
454 455 Args: 456 text: The text to check. 457 identify: 458 `"always"` or `True`: Always returns `True`. 459 `"safe"`: Only returns `True` if the identifier is case-insensitive. 460 461 Returns: 462 Whether the given text can be identified. 463 """ 464 if identify is True or identify == "always": 465 return True 466 467 if identify == "safe": 468 return not self.case_sensitive(text) 469 470 return False 471 472 def quote_identifier(self, expression: E, identify: bool = True) -> E: 473 """ 474 Adds quotes to a given identifier. 475 476 Args: 477 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 478 identify: If set to `False`, the quotes will only be added if the identifier is deemed 479 "unsafe", with respect to its characters and this dialect's normalization strategy. 480 """ 481 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 482 name = expression.this 483 expression.set( 484 "quoted", 485 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 486 ) 487 488 return expression 489 490 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 491 if isinstance(path, exp.Literal): 492 path_text = path.name 493 if path.is_number: 494 path_text = f"[{path_text}]" 495 496 try: 497 return parse_json_path(path_text) 498 except ParseError as e: 499 logger.warning(f"Invalid JSON path syntax. {str(e)}") 500 501 return path 502 503 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 504 return self.parser(**opts).parse(self.tokenize(sql), sql) 505 506 def parse_into( 507 self, expression_type: exp.IntoType, sql: str, **opts 508 ) -> t.List[t.Optional[exp.Expression]]: 509 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 510 511 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 512 return self.generator(**opts).generate(expression, copy=copy) 513 514 def transpile(self, sql: str, **opts) -> t.List[str]: 515 return [ 516 self.generate(expression, copy=False, **opts) if expression else "" 517 for expression in self.parse(sql) 518 ] 519 520 def tokenize(self, sql: str) -> t.List[Token]: 521 return self.tokenizer.tokenize(sql) 522 523 @property 524 def tokenizer(self) -> Tokenizer: 525 if not hasattr(self, "_tokenizer"): 526 self._tokenizer = self.tokenizer_class(dialect=self) 527 return self._tokenizer 528 529 def parser(self, **opts) -> Parser: 530 return self.parser_class(dialect=self, **opts) 531 532 def generator(self, **opts) -> Generator: 533 return self.generator_class(dialect=self, **opts)
386 def __init__(self, **kwargs) -> None: 387 normalization_strategy = kwargs.get("normalization_strategy") 388 389 if normalization_strategy is None: 390 self.normalization_strategy = self.NORMALIZATION_STRATEGY 391 else: 392 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of a / b
depends on the types of a
and b
.
False means a / b
is always float division.
True means a / b
is integer division if both a
and b
are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (\n
) to its unescaped version (
).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
322 @classmethod 323 def get_or_raise(cls, dialect: DialectType) -> Dialect: 324 """ 325 Look up a dialect in the global dialect registry and return it if it exists. 326 327 Args: 328 dialect: The target dialect. If this is a string, it can be optionally followed by 329 additional key-value pairs that are separated by commas and are used to specify 330 dialect settings, such as whether the dialect's identifiers are case-sensitive. 331 332 Example: 333 >>> dialect = dialect_class = get_or_raise("duckdb") 334 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 335 336 Returns: 337 The corresponding Dialect instance. 338 """ 339 340 if not dialect: 341 return cls() 342 if isinstance(dialect, _Dialect): 343 return dialect() 344 if isinstance(dialect, Dialect): 345 return dialect 346 if isinstance(dialect, str): 347 try: 348 dialect_name, *kv_pairs = dialect.split(",") 349 kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} 350 except ValueError: 351 raise ValueError( 352 f"Invalid dialect format: '{dialect}'. " 353 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 354 ) 355 356 result = cls.get(dialect_name.strip()) 357 if not result: 358 from difflib import get_close_matches 359 360 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 361 if similar: 362 similar = f" Did you mean {similar}?" 363 364 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 365 366 return result(**kwargs) 367 368 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
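A short sketch of the accepted input shapes, following the branches above (falsy input, class, instance, and string):

from sqlglot.dialects.dialect import Dialect

print(type(Dialect.get_or_raise(None)).__name__)      # Dialect -- the default
print(type(Dialect.get_or_raise("duckdb")).__name__)  # DuckDB

try:
    Dialect.get_or_raise("duckdbb")  # misspelled on purpose
except ValueError as e:
    print(e)  # Unknown dialect 'duckdbb'. Did you mean duckdb?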
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Converts a time format in this dialect to its equivalent Python `strftime` format."""
    if isinstance(expression, str):
        return exp.Literal.string(
            # the time formats are quoted
            format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
        )

    if expression and expression.is_string:
        return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

    return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
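For instance, assuming Hive's TIME_MAPPING covers these tokens, a quoted Hive-style format string maps onto its strftime equivalent:

from sqlglot.dialects.dialect import Dialect

# the plain-string form is expected to be quoted, hence the inner quotes
print(Dialect["hive"].format_time("'yyyy-MM-dd'").sql())
# '%Y-%m-%d'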
def normalize_identifier(self, expression: E) -> E:
    """
    Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

    For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
    lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
    it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
    and so any normalization would be prohibited in order to avoid "breaking" the identifier.

    There are also dialects like Spark, which are case-insensitive even when quotes are
    present, and dialects like MySQL, whose resolution rules match those employed by the
    underlying operating system, for example they may always be case-sensitive in Linux.

    Finally, the normalization behavior of some engines can even be controlled through flags,
    like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

    SQLGlot aims to understand and handle all of these different behaviors gracefully, so
    that it can analyze queries in the optimizer and successfully capture their semantics.
    """
    if (
        isinstance(expression, exp.Identifier)
        and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        and (
            not expression.quoted
            or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
        )
    ):
        expression.set(
            "this",
            (
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower()
            ),
        )

    return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
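A small illustration of the behaviors described above, using the default strategies of Postgres (lowercase) and Snowflake (uppercase):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted

print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).this)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).this)  # FOO

# a quoted identifier is left untouched under a case-sensitive-when-quoted strategy
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).this)         # FoO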
def case_sensitive(self, text: str) -> bool:
    """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.

    Returns:
        Whether the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify == "safe":
        return not self.case_sensitive(text)

    return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  `"always"` or `True`: Always returns `True`.
  `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
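A quick sketch against a lowercase-normalizing dialect such as Postgres:

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")

print(d.can_identify("foo", "safe"))    # True: no case-sensitive characters
print(d.can_identify("Foo", "safe"))    # False: the uppercase F is "unsafe" here
print(d.can_identify("Foo", "always"))  # True: always identifiable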
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
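For example, under Postgres rules (where a space fails SAFE_IDENTIFIER_RE):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")

unsafe = exp.to_identifier("foo bar")  # contains a space, so it's deemed unsafe
safe = exp.to_identifier("foo")

print(d.quote_identifier(unsafe, identify=False).sql())  # "foo bar"
print(d.quote_identifier(safe, identify=False).sql())    # foo
print(d.quote_identifier(safe.copy()).sql())             # "foo" -- identify defaults to True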
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    if isinstance(path, exp.Literal):
        path_text = path.name
        if path.is_number:
            path_text = f"[{path_text}]"

        try:
            return parse_json_path(path_text)
        except ParseError as e:
            logger.warning(f"Invalid JSON path syntax. {str(e)}")

    return path
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
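Factories like this one are meant to be plugged into a generator's TRANSFORMS mapping. A hypothetical toy dialect that spells IF as IIF and defaults a missing false branch to NULL might look like this sketch:

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.dialects.dialect import Dialect, if_sql

class IifDialect(Dialect):  # hypothetical dialect, for illustration only
    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.If: if_sql("IIF", false_value="NULL"),
        }

node = exp.If(this=exp.column("c"), true=exp.Literal.number(1))
print(IifDialect().generate(node))  # IIF(c, 1, NULL)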
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset
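The inline comment is the key detail: a starting-position argument is emulated by searching within a SUBSTR and re-adding the offset. A sketch of the resulting SQL, using the Postgres generator:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, str_position_sql

gen = Dialect.get_or_raise("postgres").generator()
node = exp.StrPosition(
    this=exp.column("haystack"),
    substr=exp.Literal.string("x"),
    position=exp.Literal.number(3),
)
print(str_position_sql(gen, node))
# STRPOS(SUBSTR(haystack, 3), 'x') + 3 - 1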
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target SQL dialect.
        default: the default format; if True, the dialect's TIME_FORMAT is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target SQL dialect.
- default: the default format; if `True`, the dialect's `TIME_FORMAT` is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
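For example, a builder targeting Hive converts the format argument at parse time (assuming 'yyyy-MM-dd' is covered by Hive's TIME_MAPPING):

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

build_str_to_time = build_formatted_time(exp.StrToTime, "hive")
node = build_str_to_time([exp.column("dt"), exp.Literal.string("yyyy-MM-dd")])

print(node.text("format"))  # %Y-%m-%d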
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the quoted column name).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table references from columns in WHEN clauses."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table references from columns in WHEN clauses.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
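A brief sketch of the builder in action: literal path arguments become a JSONPath expression rooted at `$`, with integer literals turning into subscripts:

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])

print(node.expression.sql())  # $.a[0]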
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))