sqlglot.dialects.dialect
```python
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST = True
    """Whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
```
class Dialects(str, Enum)

    Dialects supported by SQLGlot.
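Since `Dialects` subclasses `str`, its members compare equal to their string values and can be passed anywhere a dialect name is expected. A small sketch (the transpilation below mirrors the example from SQLGlot's README):

```python
import sqlglot
from sqlglot.dialects.dialect import Dialects

assert Dialects.POSTGRES == "postgres"  # str-valued enum members behave like plain strings

# Enum members can be used directly as the read/write dialects of the top-level API.
sql = sqlglot.transpile(
    "SELECT STRFTIME(x, '%y-%-m-%S')", read=Dialects.DUCKDB, write=Dialects.HIVE
)[0]
print(sql)  # SELECT DATE_FORMAT(x, 'yy-M-ss')
```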
class NormalizationStrategy(str, AutoName)

    Specifies the strategy according to which identifiers should be normalized.

    - LOWERCASE: Unquoted identifiers are lowercased.
    - UPPERCASE: Unquoted identifiers are uppercased.
    - CASE_SENSITIVE: Always case-sensitive, regardless of quotes.
    - CASE_INSENSITIVE: Always case-insensitive, regardless of quotes.
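Each dialect fixes its default strategy via the class-level `NORMALIZATION_STRATEGY`, but it can be overridden per instance through the comma-separated settings syntax accepted by `Dialect.get_or_raise` (documented below). A brief sketch:

```python
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Postgres lowercases unquoted identifiers by default...
assert Dialect.get_or_raise("postgres").normalization_strategy is NormalizationStrategy.LOWERCASE

# ...but the strategy can be overridden when the dialect is instantiated.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
```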
class Dialect(metaclass=_Dialect)

    The base class that all SQL dialects extend. `Dialect(**kwargs)` accepts a
    `normalization_strategy` setting, which overrides the class-level
    `NORMALIZATION_STRATEGY` default.

    Notable class-level settings:

    - WEEK_OFFSET: First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday).
      -1 would be Sunday.
    - TABLESAMPLE_SIZE_IS_PERCENT: Whether a size in the table sample clause represents
      percentage.
    - NORMALIZATION_STRATEGY: Specifies the strategy according to which identifiers
      should be normalized.
    - NORMALIZE_FUNCTIONS: Determines how function names are going to be normalized.
      Possible values: "upper" or True (convert names to uppercase), "lower" (convert
      names to lowercase), False (disables function name normalization).
    - NULL_ORDERING: Default NULL ordering method to use if not explicitly set.
      Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".
    - TYPED_DIVISION: Whether the behavior of `a / b` depends on the types of `a` and
      `b`. False means `a / b` is always float division; True means it is integer
      division if both `a` and `b` are integers.
    - CONCAT_COALESCE: A NULL arg in CONCAT yields NULL by default, but in some
      dialects it yields an empty string.
    - TIME_MAPPING: Associates this dialect's time formats with their equivalent Python
      strftime formats.
    - FORMAT_MAPPING: Helper which is used for parsing the special syntax
      CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be
      constructed off of TIME_MAPPING.
    - ESCAPE_SEQUENCES: Mapping of an unescaped escape sequence to the corresponding
      character.
    - PSEUDOCOLUMNS: Columns that are auto-generated by the engine corresponding to
      this dialect. For example, such columns may be excluded from SELECT * queries.
    - PREFER_CTE_ALIAS_COLUMN: Some dialects, such as Snowflake, allow you to reference
      a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE
      alias columns to override any projection aliases in the subquery. For example,
      WITH y(c) AS (SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y;
      will be rewritten as
      WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y;
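These settings drive parsing and generation, and the `parse`/`generate`/`transpile` methods wire the dialect's tokenizer, parser, and generator together. A sketch of working with an instance directly; the `INDEX_OFFSET` value shown for Presto is an assumption based on its 1-based arrays:

```python
from sqlglot.dialects.dialect import Dialect

presto = Dialect.get_or_raise("presto")
print(presto.INDEX_OFFSET)  # 1 (assumed): Presto arrays are 1-based

duckdb = Dialect.get_or_raise("duckdb")
ast = duckdb.parse("SELECT 1 AS x")[0]    # tokenize + parse into an expression tree
print(duckdb.generate(ast))               # SELECT 1 AS x
print(duckdb.transpile("SELECT 1 AS x"))  # ['SELECT 1 AS x']
```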
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect

    Look up a dialect in the global dialect registry and return it if it exists.

    Arguments:
    - dialect: The target dialect. If this is a string, it can be optionally followed by
      additional key-value pairs that are separated by commas and are used to specify
      dialect settings, such as whether the dialect's identifiers are case-sensitive.

    Example:
        >>> dialect = dialect_class = get_or_raise("duckdb")
        >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

    Returns:
        The corresponding Dialect instance.
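A sketch of the lookup behavior, including the `difflib`-based suggestion produced for near-miss names (the exact wording of the message below is inferred from the source, not verified output):

```python
from sqlglot.dialects.dialect import Dialect

assert isinstance(Dialect.get_or_raise("duckdb"), Dialect)

try:
    Dialect.get_or_raise("postgress")  # note the typo
except ValueError as e:
    print(e)  # Unknown dialect 'postgress'. Did you mean postgres?
```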
344 @classmethod 345 def format_time( 346 cls, expression: t.Optional[str | exp.Expression] 347 ) -> t.Optional[exp.Expression]: 348 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 349 if isinstance(expression, str): 350 return exp.Literal.string( 351 # the time formats are quoted 352 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 353 ) 354 355 if expression and expression.is_string: 356 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 357 358 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
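For instance, a quick sketch using the Hive dialect, whose `TIME_MAPPING` translates formats such as `yyyy-MM-dd` (exact mappings vary by dialect):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

hive = Dialect["hive"]
# A string literal holding a Hive time format is rewritten into strftime notation.
converted = hive.format_time(exp.Literal.string("yyyy-MM-dd"))
print(converted.this if converted else None)  # expected: %Y-%m-%d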
def normalize_identifier(self, expression: E) -> E:
    """
    Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

    For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
    lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
    it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
    and so any normalization would be prohibited in order to avoid "breaking" the identifier.

    There are also dialects like Spark, which are case-insensitive even when quotes are
    present, and dialects like MySQL, whose resolution rules match those employed by the
    underlying operating system, for example they may always be case-sensitive in Linux.

    Finally, the normalization behavior of some engines can even be controlled through flags,
    like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

    SQLGlot aims to understand and handle all of these different behaviors gracefully, so
    that it can analyze queries in the optimizer and successfully capture their semantics.
    """
    if (
        isinstance(expression, exp.Identifier)
        and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        and (
            not expression.quoted
            or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
        )
    ):
        expression.set(
            "this",
            (
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower()
            ),
        )

    return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
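A minimal sketch of the behaviors described above (Postgres lowercases, Snowflake uppercases, and quoting blocks normalization in both):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO", quoted=False)
print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).this)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).this)  # FOO

# A quoted identifier is treated as case-sensitive and left untouched.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("snowflake").normalize_identifier(quoted).this)        # FoO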
def case_sensitive(self, text: str) -> bool:
    """Checks if `text` contains any case-sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
Checks if `text` contains any case-sensitive characters, based on the dialect's rules.
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.

    Returns:
        Whether the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify == "safe":
        return not self.case_sensitive(text)

    return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
    `"always"` or `True`: Always returns `True`.
    `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
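For example, under a lowercase-normalizing dialect such as Postgres (a sketch; "safe" here means quoting the text cannot change how it resolves):

from sqlglot.dialects.dialect import Dialect

pg = Dialect.get_or_raise("postgres")
print(pg.can_identify("foo", "safe"))    # True: all lowercase, quoting is harmless
print(pg.can_identify("Foo", "safe"))    # False: 'F' is case-sensitive under lowercasing
print(pg.can_identify("Foo", "always"))  # True: quote unconditionally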
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
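A short sketch of both modes, again using Postgres:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

pg = Dialect.get_or_raise("postgres")

# identify=True (the default) quotes unconditionally.
print(pg.quote_identifier(exp.to_identifier("foo", quoted=False)).sql(dialect="postgres"))  # "foo"

# identify=False only quotes identifiers deemed unsafe, e.g. mixed case under lowercasing.
print(pg.quote_identifier(exp.to_identifier("Foo", quoted=False), identify=False).sql(dialect="postgres"))  # "Foo"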
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    if isinstance(path, exp.Literal):
        path_text = path.name
        if path.is_number:
            path_text = f"[{path_text}]"

        try:
            return parse_json_path(path_text)
        except ParseError as e:
            logger.warning(f"Invalid JSON path syntax. {str(e)}")

    return path
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
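Generator helpers like this one are typically wired into a dialect's TRANSFORMS map. A hypothetical sketch (MyDialect and the IIF mapping are illustrative, not taken from any real dialect module):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, if_sql
from sqlglot.generator import Generator

class MyDialect(Dialect):
    class Generator(Generator):
        # Render exp.If as IIF(cond, true, false), defaulting the false branch to NULL.
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.If: if_sql(name="IIF", false_value="NULL"),
        }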
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
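The `+ position - 1` term maps a match index found within the suffix back to an index in the original string. A worked check of the arithmetic (1-based, emulating STRPOS/SUBSTR in Python):

# Find 'b' in 'abcb' starting at position 3:
#   SUBSTR('abcb', 3) -> 'cb'; STRPOS('cb', 'b') -> 2; 2 + 3 - 1 -> 4.
s, sub, pos = "abcb", "b", 3
hit_in_suffix = s[pos - 1:].find(sub) + 1  # 1-based position within the suffix
print(hit_in_suffix + pos - 1)             # 4, the position in the original string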
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target SQL dialect.
        default: the default format; `True` means the dialect's `TIME_FORMAT`.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target SQL dialect.
- default: the default format; `True` means the dialect's `TIME_FORMAT`.
Returns:
A callable that can be used to return the appropriately formatted time expression.
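A sketch of what the returned builder produces, converting a Hive-style format along the way (the generated SQL shown is approximate):

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "hive")
node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
print(node.sql())  # roughly: STR_TO_TIME(x, '%Y-%m-%d')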
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
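Both branches can be seen with a toy dialect wiring (TrimDemo is illustrative only; real dialects register this helper the same way):

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, trim_sql
from sqlglot.generator import Generator

class TrimDemo(Dialect):
    class Generator(Generator):
        TRANSFORMS = {**Generator.TRANSFORMS, exp.Trim: trim_sql}

# Database-specific arguments present: the ANSI long form is emitted.
print(sqlglot.parse_one("SELECT TRIM(LEADING 'x' FROM col)").sql(dialect=TrimDemo))
# Plain TRIM: falls through to the generator's default TRIM handling.
print(sqlglot.parse_one("SELECT TRIM(col)").sql(dialect=TrimDemo))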
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table references from columns in WHEN clauses."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table references from columns in WHEN clauses.
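A sketch of the rewrite via a toy wiring (MergeDemo is illustrative; real dialects that disallow target-qualified columns in WHEN clauses register this helper for exp.Merge):

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, merge_without_target_sql
from sqlglot.generator import Generator

class MergeDemo(Dialect):
    class Generator(Generator):
        TRANSFORMS = {**Generator.TRANSFORMS, exp.Merge: merge_without_target_sql}

sql = "MERGE INTO t USING s ON t.id = s.id WHEN MATCHED THEN UPDATE SET t.x = s.x"
print(sqlglot.parse_one(sql).sql(dialect=MergeDemo))
# expected (roughly): MERGE INTO t USING s ON t.id = s.id WHEN MATCHED THEN UPDATE SET x = s.x
# Note: only the WHEN clauses are rewritten; the ON condition is left untouched.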
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _builder
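A sketch of the path construction (the rendered path shown is approximate):

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("j"), exp.Literal.string("a"), exp.Literal.number(0)])
print(node.expression.sql())  # roughly: '$.a[0]'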
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))