sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]

if t.TYPE_CHECKING:
    from sqlglot._typing import B, E


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """Determines the base index offset for arrays."""

    WEEK_OFFSET = 0
    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Determines whether or not the table alias comes after the TABLESAMPLE clause."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Determines whether or not a size in the table sample clause represents a percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Determines whether or not an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Determines whether or not `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Determines whether or not user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Determines whether or not `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """Determines how function names are going to be normalized."""

    LOG_BASE_FIRST = True
    """Determines whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Indicates the default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for quotes, identifiers and the corresponding escape characters
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format; if `True`, the dialect's `TIME_FORMAT` is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and DuckDB, which use functions that don't support charset and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
    from sqlglot.optimizer.simplify import simplify

    # Makes sure the path will be evaluated correctly at runtime to include the path root.
    # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
    path = expression.expression
    path = exp.func(
        "if",
        exp.func("startswith", path, "'['"),
        exp.func("concat", "'$'", path),
        exp.func("concat", "'$.'", path),
    )

    expression.expression.replace(simplify(path))
    return expression


def path_to_jsonpath(
    name: str = "JSON_EXTRACT",
) -> t.Callable[[Generator, exp.GetPath], str]:
    def _transform(self: Generator, expression: exp.GetPath) -> str:
        return rename_func(name)(self, prepend_dollar_to_path(expression))

    return _transform


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    normalize = lambda identifier: (
        self.dialect.normalize_identifier(identifier).name if identifier else None
    )

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
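The `_Dialect` metaclass auto-registers every `Dialect` subclass and makes dialect classes comparable against strings and instances. A minimal sketch of that registry behavior, assuming `sqlglot` is installed so the built-in dialects have been imported and registered:

from sqlglot.dialects.dialect import Dialect

duckdb = Dialect["duckdb"]   # registry lookup via _Dialect.__getitem__
print(duckdb == "duckdb")    # True: _Dialect.__eq__ resolves the string through the registry
print(duckdb == duckdb())    # True: an instance compares equal to its class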
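`Dialect.get_or_raise` also accepts a comma-separated settings string, per its docstring. A quick sketch of both forms:

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
print(type(mysql).__name__)                                                   # MySQL
print(mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE)  # True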
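`format_time` converts a dialect-native format literal into Python `strftime` tokens via `TIME_MAPPING`. A sketch, assuming MySQL's `TIME_MAPPING` translates its minutes token `%i` to Python's `%M` (unmapped tokens pass through unchanged):

from sqlglot.dialects.dialect import Dialect

# Raw strings are expected to be quoted format literals, hence the inner quotes.
converted = Dialect["mysql"].format_time("'%Y-%m-%d %i'")
print(converted.name)  # %Y-%m-%d %M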
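The behavior `normalize_identifier` describes can be checked directly; the Postgres and Snowflake outcomes below are the ones its docstring names. Note that the method mutates the node in place, hence the copies:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted
print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted.copy()).name)  # FoO, untouched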
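`case_sensitive`, `can_identify` and `quote_identifier` work together: under the default LOWERCASE strategy, any uppercase character makes an identifier "unsafe" and therefore worth quoting. A sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")
print(d.can_identify("foo"), d.can_identify("Foo"))  # True False
print(d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql(dialect="postgres"))  # "Foo"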
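Most of the module-level helpers are meant to be plugged into a dialect's `Generator.TRANSFORMS` table. A hypothetical mini-dialect showing the wiring; the name `MyDialect` is illustrative, though the `EDITDIST3` rename mirrors what a SQLite-style dialect might do for `LEVENSHTEIN`:

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, rename_func
from sqlglot.generator import Generator as BaseGenerator

class MyDialect(Dialect):  # auto-registered as "mydialect" by the metaclass
    class Generator(BaseGenerator):
        TRANSFORMS = {
            **BaseGenerator.TRANSFORMS,
            exp.Levenshtein: rename_func("EDITDIST3"),
        }

print(sqlglot.transpile("SELECT LEVENSHTEIN(a, b)", write="mydialect")[0])
# SELECT EDITDIST3(a, b)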
356 def normalize_identifier(self, expression: E) -> E: 357 """ 358 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 359 360 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 361 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 362 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 363 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 364 365 There are also dialects like Spark, which are case-insensitive even when quotes are 366 present, and dialects like MySQL, whose resolution rules match those employed by the 367 underlying operating system, for example they may always be case-sensitive in Linux. 368 369 Finally, the normalization behavior of some engines can even be controlled through flags, 370 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 371 372 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 373 that it can analyze queries in the optimizer and successfully capture their semantics. 374 """ 375 if ( 376 isinstance(expression, exp.Identifier) 377 and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE 378 and ( 379 not expression.quoted 380 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 381 ) 382 ): 383 expression.set( 384 "this", 385 ( 386 expression.this.upper() 387 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 388 else expression.this.lower() 389 ), 390 ) 391 392 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
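A small sketch of these rules, assuming Postgres and Snowflake keep their default strategies:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO", quoted=False)

    # Postgres lowercases unquoted identifiers, Snowflake uppercases them.
    print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).this)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).this)  # FOO

    # Quoted identifiers are left alone under case-sensitive resolution.
    quoted = exp.to_identifier("FoO", quoted=True)
    print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).this)  # FoO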
def case_sensitive(self, text: str) -> bool:
    """Checks if `text` contains any case-sensitive characters, based on the dialect's rules."""
    if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    unsafe = (
        str.islower
        if self.normalization_strategy is NormalizationStrategy.UPPERCASE
        else str.isupper
    )
    return any(unsafe(char) for char in text)
Checks if `text` contains any case-sensitive characters, based on the dialect's rules.
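For instance, under Snowflake's uppercasing strategy (a sketch; "case-sensitive" here means a character that normalization would change):

    from sqlglot.dialects.dialect import Dialect

    snowflake = Dialect.get_or_raise("snowflake")  # uppercases unquoted identifiers
    print(snowflake.case_sensitive("FOO"))  # False: already in normal form
    print(snowflake.case_sensitive("Foo"))  # True: contains lowercase characters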
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.

    Returns:
        Whether or not the given text can be identified.
    """
    if identify is True or identify == "always":
        return True

    if identify == "safe":
        return not self.case_sensitive(text)

    return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: Always returns `True`. `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
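For instance, under Postgres' lowercasing strategy (a sketch):

    from sqlglot.dialects.dialect import Dialect

    postgres = Dialect.get_or_raise("postgres")
    print(postgres.can_identify("foo", "safe"))    # True: no case-sensitive characters
    print(postgres.can_identify("Foo", "safe"))    # False: 'F' is case-sensitive here
    print(postgres.can_identify("Foo", "always"))  # True: quoting is unconditional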
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Adds quotes to a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.
    """
    if isinstance(expression, exp.Identifier):
        name = expression.this
        expression.set(
            "quoted",
            identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
        )

    return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
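A short sketch under Postgres' lowercasing strategy ("Bar" is unsafe because normalization would change it):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    postgres = Dialect.get_or_raise("postgres")

    # identify=True always quotes.
    print(postgres.quote_identifier(exp.to_identifier("foo")).sql("postgres"))  # "foo"

    # identify=False only quotes "unsafe" identifiers; mixed case is unsafe here.
    print(
        postgres.quote_identifier(exp.to_identifier("Bar"), identify=False).sql("postgres")
    )  # "Bar"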
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
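Helpers like this are typically wired into a dialect's Generator via TRANSFORMS. A minimal sketch (MyDialect is hypothetical; the metaclass registers it under the name "mydialect" automatically):

    import sqlglot
    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect, if_sql
    from sqlglot.generator import Generator

    class MyDialect(Dialect):
        class Generator(Generator):
            TRANSFORMS = {
                **Generator.TRANSFORMS,
                exp.If: if_sql(name="IIF"),  # render IF(...) as IIF(...)
            }

    print(sqlglot.transpile("SELECT IF(x > 1, 'a', 'b')", write="mydialect")[0])
    # SELECT IIF(x > 1, 'a', 'b')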
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
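The offset arithmetic can be sanity-checked in plain Python; a contrived example with 1-based positions, mirroring the generated STRPOS/SUBSTR expression:

    s, sub, pos = "abcabc", "b", 3
    suffix_hit = s[pos - 1 :].find(sub) + 1  # STRPOS(SUBSTR(s, pos), sub) -> 3
    print(suffix_hit + pos - 1)              # 5: the second 'b' in s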
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: the target SQL dialect.
        default: the default format; if `True`, the dialect's default `TIME_FORMAT` is used.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: the target SQL dialect.
- default: the default format; if `True`, the dialect's default `TIME_FORMAT` is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
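A sketch of how this factory is typically used, wiring a parser entry on a hypothetical Hive subclass (the TO_DATE mapping mirrors the pattern used in sqlglot's own Hive dialect):

    from sqlglot import exp
    from sqlglot.dialects.dialect import format_time_lambda
    from sqlglot.dialects.hive import Hive

    class MyHive(Hive):
        class Parser(Hive.Parser):
            FUNCTIONS = {
                **Hive.Parser.FUNCTIONS,
                # Parse TO_DATE(x[, fmt]) into exp.TsOrDsToDate, translating fmt
                # from Hive's time tokens; True falls back to Hive's TIME_FORMAT.
                "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive", default=True),
            }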
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
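Roughly, for Hive this behaves as below (output is indicative and may vary by sqlglot version):

    import sqlglot

    sql = "CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)"
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])
    # CREATE TABLE t (a INT) PARTITIONED BY (b INT)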
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
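For example, MySQL parses DATE_ADD through this helper; a quick sketch of the resulting tree (outputs indicative):

    import sqlglot

    node = sqlglot.parse_one("DATE_ADD(d, INTERVAL 1 DAY)", read="mysql")
    print(type(node).__name__)  # DateAdd
    print(node.text("unit"))    # DAY
    print(node.sql("mysql"))    # DATE_ADD(d, INTERVAL 1 DAY)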
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql
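A sketch of wiring date arithmetic through this factory in a dialect's TRANSFORMS (names are illustrative, not tied to any specific dialect):

    from sqlglot import exp
    from sqlglot.dialects.dialect import date_delta_sql

    TRANSFORMS = {
        exp.DateAdd: date_delta_sql("DATE_ADD"),               # DATE_ADD(DAY, 3, x)
        exp.DateDiff: date_delta_sql("DATE_DIFF"),             # DATE_DIFF(DAY, y, x)
        exp.TsOrDsAdd: date_delta_sql("DATE_ADD", cast=True),  # casts operands first
    }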
def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
    from sqlglot.optimizer.simplify import simplify

    # Makes sure the path will be evaluated correctly at runtime to include the path root.
    # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
    path = expression.expression
    path = exp.func(
        "if",
        exp.func("startswith", path, "'['"),
        exp.func("concat", "'$'", path),
        exp.func("concat", "'$.'", path),
    )

    expression.expression.replace(simplify(path))
    return expression
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    normalize = lambda identifier: (
        self.dialect.normalize_identifier(identifier).name if identifier else None
    )

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)
Remove table refs from columns in when statements.