sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents a percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST = True
    """Whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
class Dialects(str, Enum):
Dialects supported by SQLGlot.
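
Since the enum also subclasses str, its members can be passed anywhere a dialect name is expected. A small usage sketch:

import sqlglot
from sqlglot.dialects.dialect import Dialects

# Enum members are plain strings under the hood, so they work as read/write keys.
print(sqlglot.transpile("SELECT 1", read=Dialects.MYSQL, write=Dialects.DUCKDB)[0])  # SELECT 1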
class NormalizationStrategy(str, AutoName):
Specifies the strategy according to which identifiers should be normalized.

LOWERCASE: Unquoted identifiers are lowercased.

UPPERCASE: Unquoted identifiers are uppercased.

CASE_SENSITIVE: Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE: Always case-insensitive, regardless of quotes.
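
A short sketch of how a strategy is attached to a dialect instance (the string value is upper-cased and fed to this enum in Dialect.__init__):

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Bundled defaults: Postgres lowercases, Snowflake uppercases.
assert Dialect.get_or_raise("postgres").normalization_strategy is NormalizationStrategy.LOWERCASE
assert Dialect.get_or_raise("snowflake").normalization_strategy is NormalizationStrategy.UPPERCASE

# Per-instance override via a constructor keyword argument.
d = Dialect.get_or_raise("postgres, normalization_strategy = case_sensitive")
assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE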
class Dialect(metaclass=_Dialect):
Dialect(**kwargs)
WEEK_OFFSET: First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

TABLESAMPLE_SIZE_IS_PERCENT: Whether a size in the table sample clause represents a percentage.

NORMALIZATION_STRATEGY: Specifies the strategy according to which identifiers should be normalized.

NORMALIZE_FUNCTIONS: Determines how function names are going to be normalized. Possible values: "upper" or True (convert names to uppercase), "lower" (convert names to lowercase), False (disables function name normalization).

NULL_ORDERING: Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last".

TYPED_DIVISION: Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division; True means a / b is integer division if both a and b are integers.

CONCAT_COALESCE: A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

TIME_MAPPING: Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

ESCAPE_SEQUENCES: Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN: Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
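
These settings are plain class attributes, so the registry makes it easy to compare engine behaviors. A small inspection sketch (the printed values depend on the bundled dialect definitions):

from sqlglot.dialects.dialect import Dialect

for name in ("duckdb", "presto", "mysql"):
    d = Dialect.get_or_raise(name)
    print(name, d.INDEX_OFFSET, d.NULL_ORDERING, d.SAFE_DIVISION)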
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
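
A sketch of both lookup forms; the settings suffix is split on commas and "=", then forwarded to the dialect's constructor as keyword arguments:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Plain lookup: returns an instance of the registered class.
duckdb = Dialect.get_or_raise("duckdb")

# Lookup with inline settings.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

# Unknown names raise a ValueError, with a close-match hint when one exists.
try:
    Dialect.get_or_raise("duckbd")
except ValueError as e:
    print(e)  # Unknown dialect 'duckbd'. Did you mean duckdb?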
@classmethod
def format_time(cls, expression: t.Optional[str | exp.Expression]) -> t.Optional[exp.Expression]:
Converts a time format in this dialect to its equivalent Python strftime format.
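For example, using MySQL's time mapping (which, in recent sqlglot versions, maps %i to %M and %s to %S), a minimal sketch:

from sqlglot.dialects.mysql import MySQL

# String inputs are expected to be quoted literals, hence the inner quotes;
# tokens with no mapping (like %H here) pass through unchanged.
converted = MySQL.format_time("'%H:%i:%s'")
print(converted)  # '%H:%M:%S'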
def normalize_identifier(self, expression: E) -> E:
    """
    Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

    For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
    lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
    it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
    and so any normalization would be prohibited in order to avoid "breaking" the identifier.

    There are also dialects like Spark, which are case-insensitive even when quotes are
    present, and dialects like MySQL, whose resolution rules match those employed by the
    underlying operating system, for example they may always be case-sensitive in Linux.

    Finally, the normalization behavior of some engines can even be controlled through flags,
    like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

    SQLGlot aims to understand and handle all of these different behaviors gracefully, so
    that it can analyze queries in the optimizer and successfully capture their semantics.
    """
    if (
        isinstance(expression, exp.Identifier)
        and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        and (
            not expression.quoted
            or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
        )
    ):
        expression.set(
            "this",
            (
                expression.this.upper()
                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                else expression.this.lower()
            ),
        )

    return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
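A minimal sketch of the Postgres/Snowflake contrast described above:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted, mixed case

print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).this)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).this)  # FOO

# Quoted identifiers are left untouched by dialects that treat quotes as case-sensitive.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).this)  # FoO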
def case_sensitive(self, text: str) -> bool:
    """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.

    Returns:
        Whether the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify == "safe":
        return not self.case_sensitive(text)

    return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  "always" or True: Always returns True.
  "safe": Only returns True if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
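For instance, under Postgres rules (unquoted identifiers are lowercased), a sketch:

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")

print(d.can_identify("foo", "safe"))    # True: all-lowercase survives normalization
print(d.can_identify("Foo", "safe"))    # False: the uppercase F is case-sensitive
print(d.can_identify("Foo", "always"))  # True, regardless of safety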
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an Identifier, this method is a no-op.
- identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
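A sketch of the identify=False behavior under Postgres normalization:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")

# "FoO" is case-sensitive for Postgres, so it gets quoted even with identify=False;
# a safe all-lowercase name is left unquoted.
print(d.quote_identifier(exp.to_identifier("FoO"), identify=False).sql("postgres"))  # "FoO"
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql("postgres"))  # foo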
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    if isinstance(path, exp.Literal):
        path_text = path.name
        if path.is_number:
            path_text = f"[{path_text}]"

        try:
            return parse_json_path(path_text)
        except ParseError as e:
            logger.warning(f"Invalid JSON path syntax. {str(e)}")

    return path
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format; if True, the dialect's TIME_FORMAT is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; if True, the dialect's TIME_FORMAT is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
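For example, wiring up a MySQL-style STR_TO_DATE parser; a sketch assuming MySQL's mapping converts %i/%s to Python's %M/%S:

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

# Returns a builder suitable for a Parser's FUNCTIONS table, e.g.
# "STR_TO_DATE": build_formatted_time(exp.StrToDate, "mysql").
builder = build_formatted_time(exp.StrToDate, "mysql")
node = builder([exp.column("x"), exp.Literal.string("%H:%i:%s")])
print(node.args["format"].name)  # %H:%M:%S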
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Removes table references from columns in WHEN clauses."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Removes table references from columns in WHEN clauses.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _builder
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))