sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        klass.INVERSE_CREATABLE_KIND_MAPPING = {
            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
        }

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "clickhouse"):
            klass.generator_class.SUPPORTS_NULLABLE_TYPES = False

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""
    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Whether COPY statement parameters are separated by comma or whitespace."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.

    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns())
    across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it resolves
          to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query, i.e. it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3),
    as the former is of type INT[] vs the latter which is SUPER
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g. in DuckDB. In
    dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator.
    """

    CREATABLE_KIND_MAPPING: dict[str, str] = {}
    """
    Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse
    equivalent of CREATE SCHEMA is CREATE DATABASE.
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Div,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Union[bool, str, None] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                        # Coerce the value to boolean if it matches the truthy/falsy values below
                        value_lower = value.lower()
                        if value_lower in ("true", "1"):
                            value = True
                        elif value_lower in ("false", "0"):
                            value = False

                    kwargs[key] = value
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )
            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = None
        if unit_based or default_unit:
            unit = args[0] if unit_based else exp.Literal.string(default_unit)
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
    start = expression.args.get("start")
    end = expression.args.get("end")
    step = expression.args.get("step")

    if isinstance(start, exp.Cast):
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if start and end and target_type and target_type.is_type("date", "timestamp"):
        if isinstance(start, exp.Cast) and target_type is start.to:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)
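The pieces above are easiest to see end to end. Below is a minimal, illustrative sketch of the registry lookup and the tokenize/parse/generate pipeline that `transpile` composes; the dialect names, the settings string, and the SQL input are arbitrary example values.

    import sqlglot.dialects  # importing the package registers the bundled dialects
    from sqlglot.dialects.dialect import Dialect

    dialect = Dialect.get_or_raise("duckdb")
    configured = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

    expressions = dialect.parse("SELECT 1 AS x")  # tokenize + parse into an AST
    print(dialect.generate(expressions[0]))       # generate SQL back: SELECT 1 AS x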
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
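Because the `_Dialect` metaclass registers every subclass under its lowercased class name and autofills the tokenizer/parser/generator classes, defining a new dialect is mostly declarative. A hedged sketch, assuming a hypothetical `Custom` dialect that quotes identifiers with backticks and renames one function via `rename_func`:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect, rename_func
    from sqlglot.generator import Generator
    from sqlglot.tokens import Tokenizer

    class Custom(Dialect):
        class Tokenizer(Tokenizer):
            IDENTIFIERS = ["`"]  # backtick-quoted identifiers

        class Generator(Generator):
            TRANSFORMS = {
                **Generator.TRANSFORMS,
                exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            }

    assert Dialect.get_or_raise("custom")  # now resolvable by class name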
Dialects supported by SQLGLot.
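The effect of `normalize_identifier` is easiest to see by comparing strategies, as the docstring describes for Postgres and Snowflake. A small sketch; `FoO` is an arbitrary identifier:

    import sqlglot.dialects  # registers the bundled dialects
    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    postgres = Dialect.get_or_raise("postgres")    # LOWERCASE strategy
    snowflake = Dialect.get_or_raise("snowflake")  # UPPERCASE strategy

    print(postgres.normalize_identifier(exp.to_identifier("FoO")).name)   # foo
    print(snowflake.normalize_identifier(exp.to_identifier("FoO")).name)  # FOO

    # Quoted identifiers are treated as case-sensitive and left untouched here:
    print(postgres.normalize_identifier(exp.to_identifier("FoO", quoted=True)).name)  # FoO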
Inherited Members
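The `TIME_MAPPING`/`TIME_TRIE` machinery ultimately calls `sqlglot.time.format_time`, which rewrites a dialect's format tokens into Python `strftime` tokens using the trie built in the metaclass. A toy mapping (not any real dialect's `TIME_MAPPING`) makes the behavior concrete:

    from sqlglot.time import format_time

    mapping = {"yyyy": "%Y", "MM": "%m", "dd": "%d"}  # hypothetical example mapping
    print(format_time("yyyy-MM-dd", mapping))         # %Y-%m-%d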
- enum.Enum
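`DATE_PART_MAPPING` canonicalizes date-part synonyms through `map_date_part`; a one-line sketch using the base dialect's mapping:

    from sqlglot import exp
    from sqlglot.dialects.dialect import map_date_part

    print(map_date_part(exp.var("dow")).name)  # DAYOFWEEK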
- name
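Finally, `to_json_path` parses a dialect's JSON path literals into a structured `exp.JSONPath` tree, falling back to the raw literal (with a warning) on invalid syntax. A sketch with an arbitrary path; the printed output shown in the comment is an assumption about the default generator:

    import sqlglot.dialects  # registers the bundled dialects
    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    path = Dialect.get_or_raise("duckdb").to_json_path(exp.Literal.string("$.a[0]"))
    print(path.sql())  # assumed to round-trip as: $.a[0]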
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.

- LOWERCASE: Unquoted identifiers are lowercased.
- UPPERCASE: Unquoted identifiers are uppercased.
- CASE_SENSITIVE: Always case-sensitive, regardless of quotes.
- CASE_INSENSITIVE: Always case-insensitive, regardless of quotes.
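The strategy can be set per `Dialect` instance via the `normalization_strategy` constructor kwarg (see `Dialect.__init__` below). A minimal sketch, assuming the default dialect's lowercase strategy (outputs indicative):

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect().normalize_identifier(exp.to_identifier("FoO")).this
'foo'
>>> Dialect(normalization_strategy="uppercase").normalize_identifier(exp.to_identifier("FoO")).this
'FOO'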
222class Dialect(metaclass=_Dialect): 223 INDEX_OFFSET = 0 224 """The base index offset for arrays.""" 225 226 WEEK_OFFSET = 0 227 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 228 229 UNNEST_COLUMN_ONLY = False 230 """Whether `UNNEST` table aliases are treated as column aliases.""" 231 232 ALIAS_POST_TABLESAMPLE = False 233 """Whether the table alias comes after tablesample.""" 234 235 TABLESAMPLE_SIZE_IS_PERCENT = False 236 """Whether a size in the table sample clause represents percentage.""" 237 238 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 239 """Specifies the strategy according to which identifiers should be normalized.""" 240 241 IDENTIFIERS_CAN_START_WITH_DIGIT = False 242 """Whether an unquoted identifier can start with a digit.""" 243 244 DPIPE_IS_STRING_CONCAT = True 245 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 246 247 STRICT_STRING_CONCAT = False 248 """Whether `CONCAT`'s arguments must be strings.""" 249 250 SUPPORTS_USER_DEFINED_TYPES = True 251 """Whether user-defined data types are supported.""" 252 253 SUPPORTS_SEMI_ANTI_JOIN = True 254 """Whether `SEMI` or `ANTI` joins are supported.""" 255 256 SUPPORTS_COLUMN_JOIN_MARKS = False 257 """Whether the old-style outer join (+) syntax is supported.""" 258 259 COPY_PARAMS_ARE_CSV = True 260 """Separator of COPY statement parameters.""" 261 262 NORMALIZE_FUNCTIONS: bool | str = "upper" 263 """ 264 Determines how function names are going to be normalized. 265 Possible values: 266 "upper" or True: Convert names to uppercase. 267 "lower": Convert names to lowercase. 268 False: Disables function name normalization. 269 """ 270 271 LOG_BASE_FIRST: t.Optional[bool] = True 272 """ 273 Whether the base comes first in the `LOG` function. 274 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 275 """ 276 277 NULL_ORDERING = "nulls_are_small" 278 """ 279 Default `NULL` ordering method to use if not explicitly set. 280 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 281 """ 282 283 TYPED_DIVISION = False 284 """ 285 Whether the behavior of `a / b` depends on the types of `a` and `b`. 286 False means `a / b` is always float division. 287 True means `a / b` is integer division if both `a` and `b` are integers. 288 """ 289 290 SAFE_DIVISION = False 291 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 292 293 CONCAT_COALESCE = False 294 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 295 296 HEX_LOWERCASE = False 297 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 298 299 DATE_FORMAT = "'%Y-%m-%d'" 300 DATEINT_FORMAT = "'%Y%m%d'" 301 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 302 303 TIME_MAPPING: t.Dict[str, str] = {} 304 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 305 306 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 307 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 308 FORMAT_MAPPING: t.Dict[str, str] = {} 309 """ 310 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 311 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
312 """ 313 314 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 315 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 316 317 PSEUDOCOLUMNS: t.Set[str] = set() 318 """ 319 Columns that are auto-generated by the engine corresponding to this dialect. 320 For example, such columns may be excluded from `SELECT *` queries. 321 """ 322 323 PREFER_CTE_ALIAS_COLUMN = False 324 """ 325 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 326 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 327 any projection aliases in the subquery. 328 329 For example, 330 WITH y(c) AS ( 331 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 332 ) SELECT c FROM y; 333 334 will be rewritten as 335 336 WITH y(c) AS ( 337 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 338 ) SELECT c FROM y; 339 """ 340 341 COPY_PARAMS_ARE_CSV = True 342 """ 343 Whether COPY statement parameters are separated by comma or whitespace 344 """ 345 346 FORCE_EARLY_ALIAS_REF_EXPANSION = False 347 """ 348 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 349 350 For example: 351 WITH data AS ( 352 SELECT 353 1 AS id, 354 2 AS my_id 355 ) 356 SELECT 357 id AS my_id 358 FROM 359 data 360 WHERE 361 my_id = 1 362 GROUP BY 363 my_id, 364 HAVING 365 my_id = 1 366 367 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 368 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 369 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 370 """ 371 372 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 373 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 374 375 SUPPORTS_ORDER_BY_ALL = False 376 """ 377 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 378 """ 379 380 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 381 """ 382 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 383 as the former is of type INT[] vs the latter which is SUPER 384 """ 385 386 SUPPORTS_FIXED_SIZE_ARRAYS = False 387 """ 388 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In 389 dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator 390 """ 391 392 CREATABLE_KIND_MAPPING: dict[str, str] = {} 393 """ 394 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 395 equivalent of CREATE SCHEMA is CREATE DATABASE. 
396 """ 397 398 # --- Autofilled --- 399 400 tokenizer_class = Tokenizer 401 jsonpath_tokenizer_class = JSONPathTokenizer 402 parser_class = Parser 403 generator_class = Generator 404 405 # A trie of the time_mapping keys 406 TIME_TRIE: t.Dict = {} 407 FORMAT_TRIE: t.Dict = {} 408 409 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 410 INVERSE_TIME_TRIE: t.Dict = {} 411 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 412 INVERSE_FORMAT_TRIE: t.Dict = {} 413 414 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 415 416 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 417 418 # Delimiters for string literals and identifiers 419 QUOTE_START = "'" 420 QUOTE_END = "'" 421 IDENTIFIER_START = '"' 422 IDENTIFIER_END = '"' 423 424 # Delimiters for bit, hex, byte and unicode literals 425 BIT_START: t.Optional[str] = None 426 BIT_END: t.Optional[str] = None 427 HEX_START: t.Optional[str] = None 428 HEX_END: t.Optional[str] = None 429 BYTE_START: t.Optional[str] = None 430 BYTE_END: t.Optional[str] = None 431 UNICODE_START: t.Optional[str] = None 432 UNICODE_END: t.Optional[str] = None 433 434 DATE_PART_MAPPING = { 435 "Y": "YEAR", 436 "YY": "YEAR", 437 "YYY": "YEAR", 438 "YYYY": "YEAR", 439 "YR": "YEAR", 440 "YEARS": "YEAR", 441 "YRS": "YEAR", 442 "MM": "MONTH", 443 "MON": "MONTH", 444 "MONS": "MONTH", 445 "MONTHS": "MONTH", 446 "D": "DAY", 447 "DD": "DAY", 448 "DAYS": "DAY", 449 "DAYOFMONTH": "DAY", 450 "DAY OF WEEK": "DAYOFWEEK", 451 "WEEKDAY": "DAYOFWEEK", 452 "DOW": "DAYOFWEEK", 453 "DW": "DAYOFWEEK", 454 "WEEKDAY_ISO": "DAYOFWEEKISO", 455 "DOW_ISO": "DAYOFWEEKISO", 456 "DW_ISO": "DAYOFWEEKISO", 457 "DAY OF YEAR": "DAYOFYEAR", 458 "DOY": "DAYOFYEAR", 459 "DY": "DAYOFYEAR", 460 "W": "WEEK", 461 "WK": "WEEK", 462 "WEEKOFYEAR": "WEEK", 463 "WOY": "WEEK", 464 "WY": "WEEK", 465 "WEEK_ISO": "WEEKISO", 466 "WEEKOFYEARISO": "WEEKISO", 467 "WEEKOFYEAR_ISO": "WEEKISO", 468 "Q": "QUARTER", 469 "QTR": "QUARTER", 470 "QTRS": "QUARTER", 471 "QUARTERS": "QUARTER", 472 "H": "HOUR", 473 "HH": "HOUR", 474 "HR": "HOUR", 475 "HOURS": "HOUR", 476 "HRS": "HOUR", 477 "M": "MINUTE", 478 "MI": "MINUTE", 479 "MIN": "MINUTE", 480 "MINUTES": "MINUTE", 481 "MINS": "MINUTE", 482 "S": "SECOND", 483 "SEC": "SECOND", 484 "SECONDS": "SECOND", 485 "SECS": "SECOND", 486 "MS": "MILLISECOND", 487 "MSEC": "MILLISECOND", 488 "MSECS": "MILLISECOND", 489 "MSECOND": "MILLISECOND", 490 "MSECONDS": "MILLISECOND", 491 "MILLISEC": "MILLISECOND", 492 "MILLISECS": "MILLISECOND", 493 "MILLISECON": "MILLISECOND", 494 "MILLISECONDS": "MILLISECOND", 495 "US": "MICROSECOND", 496 "USEC": "MICROSECOND", 497 "USECS": "MICROSECOND", 498 "MICROSEC": "MICROSECOND", 499 "MICROSECS": "MICROSECOND", 500 "USECOND": "MICROSECOND", 501 "USECONDS": "MICROSECOND", 502 "MICROSECONDS": "MICROSECOND", 503 "NS": "NANOSECOND", 504 "NSEC": "NANOSECOND", 505 "NANOSEC": "NANOSECOND", 506 "NSECOND": "NANOSECOND", 507 "NSECONDS": "NANOSECOND", 508 "NANOSECS": "NANOSECOND", 509 "EPOCH_SECOND": "EPOCH", 510 "EPOCH_SECONDS": "EPOCH", 511 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 512 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 513 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 514 "TZH": "TIMEZONE_HOUR", 515 "TZM": "TIMEZONE_MINUTE", 516 "DEC": "DECADE", 517 "DECS": "DECADE", 518 "DECADES": "DECADE", 519 "MIL": "MILLENIUM", 520 "MILS": "MILLENIUM", 521 "MILLENIA": "MILLENIUM", 522 "C": "CENTURY", 523 "CENT": "CENTURY", 524 "CENTS": "CENTURY", 525 "CENTURIES": "CENTURY", 526 } 527 528 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 529 exp.DataType.Type.BIGINT: 
{ 530 exp.ApproxDistinct, 531 exp.ArraySize, 532 exp.Count, 533 exp.Length, 534 }, 535 exp.DataType.Type.BOOLEAN: { 536 exp.Between, 537 exp.Boolean, 538 exp.In, 539 exp.RegexpLike, 540 }, 541 exp.DataType.Type.DATE: { 542 exp.CurrentDate, 543 exp.Date, 544 exp.DateFromParts, 545 exp.DateStrToDate, 546 exp.DiToDate, 547 exp.StrToDate, 548 exp.TimeStrToDate, 549 exp.TsOrDsToDate, 550 }, 551 exp.DataType.Type.DATETIME: { 552 exp.CurrentDatetime, 553 exp.Datetime, 554 exp.DatetimeAdd, 555 exp.DatetimeSub, 556 }, 557 exp.DataType.Type.DOUBLE: { 558 exp.ApproxQuantile, 559 exp.Avg, 560 exp.Div, 561 exp.Exp, 562 exp.Ln, 563 exp.Log, 564 exp.Pow, 565 exp.Quantile, 566 exp.Round, 567 exp.SafeDivide, 568 exp.Sqrt, 569 exp.Stddev, 570 exp.StddevPop, 571 exp.StddevSamp, 572 exp.Variance, 573 exp.VariancePop, 574 }, 575 exp.DataType.Type.INT: { 576 exp.Ceil, 577 exp.DatetimeDiff, 578 exp.DateDiff, 579 exp.TimestampDiff, 580 exp.TimeDiff, 581 exp.DateToDi, 582 exp.Levenshtein, 583 exp.Sign, 584 exp.StrPosition, 585 exp.TsOrDiToDi, 586 }, 587 exp.DataType.Type.JSON: { 588 exp.ParseJSON, 589 }, 590 exp.DataType.Type.TIME: { 591 exp.Time, 592 }, 593 exp.DataType.Type.TIMESTAMP: { 594 exp.CurrentTime, 595 exp.CurrentTimestamp, 596 exp.StrToTime, 597 exp.TimeAdd, 598 exp.TimeStrToTime, 599 exp.TimeSub, 600 exp.TimestampAdd, 601 exp.TimestampSub, 602 exp.UnixToTime, 603 }, 604 exp.DataType.Type.TINYINT: { 605 exp.Day, 606 exp.Month, 607 exp.Week, 608 exp.Year, 609 exp.Quarter, 610 }, 611 exp.DataType.Type.VARCHAR: { 612 exp.ArrayConcat, 613 exp.Concat, 614 exp.ConcatWs, 615 exp.DateToDateStr, 616 exp.GroupConcat, 617 exp.Initcap, 618 exp.Lower, 619 exp.Substring, 620 exp.TimeToStr, 621 exp.TimeToTimeStr, 622 exp.Trim, 623 exp.TsOrDsToDateStr, 624 exp.UnixToStr, 625 exp.UnixToTimeStr, 626 exp.Upper, 627 }, 628 } 629 630 ANNOTATORS: AnnotatorsType = { 631 **{ 632 expr_type: lambda self, e: self._annotate_unary(e) 633 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 634 }, 635 **{ 636 expr_type: lambda self, e: self._annotate_binary(e) 637 for expr_type in subclasses(exp.__name__, exp.Binary) 638 }, 639 **{ 640 expr_type: _annotate_with_type_lambda(data_type) 641 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 642 for expr_type in expressions 643 }, 644 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 645 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 646 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 647 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 648 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 649 exp.Bracket: lambda self, e: self._annotate_bracket(e), 650 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 651 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 652 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 653 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 654 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 655 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 656 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 657 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 658 exp.Div: lambda self, e: self._annotate_div(e), 659 exp.Dot: lambda self, e: self._annotate_dot(e), 660 exp.Explode: lambda self, e: self._annotate_explode(e), 661 exp.Extract: lambda self, e: self._annotate_extract(e), 
662 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 663 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 664 e, exp.DataType.build("ARRAY<DATE>") 665 ), 666 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 667 e, exp.DataType.build("ARRAY<TIMESTAMP>") 668 ), 669 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 670 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 671 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 672 exp.Literal: lambda self, e: self._annotate_literal(e), 673 exp.Map: lambda self, e: self._annotate_map(e), 674 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 675 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 676 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 677 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 678 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 679 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 680 exp.Struct: lambda self, e: self._annotate_struct(e), 681 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 682 exp.Timestamp: lambda self, e: self._annotate_with_type( 683 e, 684 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 685 ), 686 exp.ToMap: lambda self, e: self._annotate_to_map(e), 687 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 688 exp.Unnest: lambda self, e: self._annotate_unnest(e), 689 exp.VarMap: lambda self, e: self._annotate_map(e), 690 } 691 692 @classmethod 693 def get_or_raise(cls, dialect: DialectType) -> Dialect: 694 """ 695 Look up a dialect in the global dialect registry and return it if it exists. 696 697 Args: 698 dialect: The target dialect. If this is a string, it can be optionally followed by 699 additional key-value pairs that are separated by commas and are used to specify 700 dialect settings, such as whether the dialect's identifiers are case-sensitive. 701 702 Example: 703 >>> dialect = dialect_class = get_or_raise("duckdb") 704 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 705 706 Returns: 707 The corresponding Dialect instance. 708 """ 709 710 if not dialect: 711 return cls() 712 if isinstance(dialect, _Dialect): 713 return dialect() 714 if isinstance(dialect, Dialect): 715 return dialect 716 if isinstance(dialect, str): 717 try: 718 dialect_name, *kv_strings = dialect.split(",") 719 kv_pairs = (kv.split("=") for kv in kv_strings) 720 kwargs = {} 721 for pair in kv_pairs: 722 key = pair[0].strip() 723 value: t.Union[bool | str | None] = None 724 725 if len(pair) == 1: 726 # Default initialize standalone settings to True 727 value = True 728 elif len(pair) == 2: 729 value = pair[1].strip() 730 731 # Coerce the value to boolean if it matches to the truthy/falsy values below 732 value_lower = value.lower() 733 if value_lower in ("true", "1"): 734 value = True 735 elif value_lower in ("false", "0"): 736 value = False 737 738 kwargs[key] = value 739 740 except ValueError: 741 raise ValueError( 742 f"Invalid dialect format: '{dialect}'. " 743 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
744 ) 745 746 result = cls.get(dialect_name.strip()) 747 if not result: 748 from difflib import get_close_matches 749 750 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 751 if similar: 752 similar = f" Did you mean {similar}?" 753 754 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 755 756 return result(**kwargs) 757 758 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 759 760 @classmethod 761 def format_time( 762 cls, expression: t.Optional[str | exp.Expression] 763 ) -> t.Optional[exp.Expression]: 764 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 765 if isinstance(expression, str): 766 return exp.Literal.string( 767 # the time formats are quoted 768 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 769 ) 770 771 if expression and expression.is_string: 772 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 773 774 return expression 775 776 def __init__(self, **kwargs) -> None: 777 normalization_strategy = kwargs.pop("normalization_strategy", None) 778 779 if normalization_strategy is None: 780 self.normalization_strategy = self.NORMALIZATION_STRATEGY 781 else: 782 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 783 784 self.settings = kwargs 785 786 def __eq__(self, other: t.Any) -> bool: 787 # Does not currently take dialect state into account 788 return type(self) == other 789 790 def __hash__(self) -> int: 791 # Does not currently take dialect state into account 792 return hash(type(self)) 793 794 def normalize_identifier(self, expression: E) -> E: 795 """ 796 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 797 798 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 799 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 800 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 801 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 802 803 There are also dialects like Spark, which are case-insensitive even when quotes are 804 present, and dialects like MySQL, whose resolution rules match those employed by the 805 underlying operating system, for example they may always be case-sensitive in Linux. 806 807 Finally, the normalization behavior of some engines can even be controlled through flags, 808 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 809 810 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 811 that it can analyze queries in the optimizer and successfully capture their semantics. 
812 """ 813 if ( 814 isinstance(expression, exp.Identifier) 815 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 816 and ( 817 not expression.quoted 818 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 819 ) 820 ): 821 expression.set( 822 "this", 823 ( 824 expression.this.upper() 825 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 826 else expression.this.lower() 827 ), 828 ) 829 830 return expression 831 832 def case_sensitive(self, text: str) -> bool: 833 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 834 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 835 return False 836 837 unsafe = ( 838 str.islower 839 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 840 else str.isupper 841 ) 842 return any(unsafe(char) for char in text) 843 844 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 845 """Checks if text can be identified given an identify option. 846 847 Args: 848 text: The text to check. 849 identify: 850 `"always"` or `True`: Always returns `True`. 851 `"safe"`: Only returns `True` if the identifier is case-insensitive. 852 853 Returns: 854 Whether the given text can be identified. 855 """ 856 if identify is True or identify == "always": 857 return True 858 859 if identify == "safe": 860 return not self.case_sensitive(text) 861 862 return False 863 864 def quote_identifier(self, expression: E, identify: bool = True) -> E: 865 """ 866 Adds quotes to a given identifier. 867 868 Args: 869 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 870 identify: If set to `False`, the quotes will only be added if the identifier is deemed 871 "unsafe", with respect to its characters and this dialect's normalization strategy. 872 """ 873 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 874 name = expression.this 875 expression.set( 876 "quoted", 877 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 878 ) 879 880 return expression 881 882 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 883 if isinstance(path, exp.Literal): 884 path_text = path.name 885 if path.is_number: 886 path_text = f"[{path_text}]" 887 try: 888 return parse_json_path(path_text, self) 889 except ParseError as e: 890 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 891 892 return path 893 894 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 895 return self.parser(**opts).parse(self.tokenize(sql), sql) 896 897 def parse_into( 898 self, expression_type: exp.IntoType, sql: str, **opts 899 ) -> t.List[t.Optional[exp.Expression]]: 900 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 901 902 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 903 return self.generator(**opts).generate(expression, copy=copy) 904 905 def transpile(self, sql: str, **opts) -> t.List[str]: 906 return [ 907 self.generate(expression, copy=False, **opts) if expression else "" 908 for expression in self.parse(sql) 909 ] 910 911 def tokenize(self, sql: str) -> t.List[Token]: 912 return self.tokenizer.tokenize(sql) 913 914 @property 915 def tokenizer(self) -> Tokenizer: 916 return self.tokenizer_class(dialect=self) 917 918 @property 919 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 920 return self.jsonpath_tokenizer_class(dialect=self) 921 922 def parser(self, **opts) -> Parser: 923 return self.parser_class(dialect=self, **opts) 924 925 def generator(self, **opts) -> Generator: 926 return self.generator_class(dialect=self, **opts)
776 def __init__(self, **kwargs) -> None: 777 normalization_strategy = kwargs.pop("normalization_strategy", None) 778 779 if normalization_strategy is None: 780 self.normalization_strategy = self.NORMALIZATION_STRATEGY 781 else: 782 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 783 784 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized. Possible values:

- `"upper"` or `True`: Convert names to uppercase.
- `"lower"`: Convert names to lowercase.
- `False`: Disables function name normalization.
Whether the base comes first in the `LOG` function. Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
Default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
Whether the behavior of `a / b` depends on the types of `a` and `b`.
`False` means `a / b` is always float division.
`True` means `a / b` is integer division if both `a` and `b` are integers.
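These semantics flags are plain class attributes, so you can inspect how transpilation targets differ before converting a query; a minimal sketch (printed values depend on each dialect's definition in the installed version):

from sqlglot.dialects.dialect import Dialect

for name in ("duckdb", "mysql", "tsql"):
    dialect = Dialect.get_or_raise(name)
    # each of these is a class attribute documented above
    print(name, dialect.TYPED_DIVISION, dialect.SAFE_DIVISION, dialect.NULL_ORDERING)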
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` formats.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an escaped sequence (e.g. `\n`) to its unescaped version (the corresponding control character).
Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

For example:

WITH data AS (
    SELECT
        1 AS id,
        2 AS my_id
)
SELECT
    id AS my_id
FROM
    data
WHERE
    my_id = 1
GROUP BY
    my_id
HAVING
    my_id = 1

In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:

- BigQuery, which will forward the alias to the GROUP BY and HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- ClickHouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (it expands to all the selected columns), as in DuckDB and Spark3/Databricks.
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] while the latter is SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array definitions/casts, e.g. in DuckDB. In dialects that don't support fixed-size arrays, such as Snowflake, this should be interpreted as a subscript/index operator.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
692 @classmethod 693 def get_or_raise(cls, dialect: DialectType) -> Dialect: 694 """ 695 Look up a dialect in the global dialect registry and return it if it exists. 696 697 Args: 698 dialect: The target dialect. If this is a string, it can be optionally followed by 699 additional key-value pairs that are separated by commas and are used to specify 700 dialect settings, such as whether the dialect's identifiers are case-sensitive. 701 702 Example: 703 >>> dialect = dialect_class = get_or_raise("duckdb") 704 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 705 706 Returns: 707 The corresponding Dialect instance. 708 """ 709 710 if not dialect: 711 return cls() 712 if isinstance(dialect, _Dialect): 713 return dialect() 714 if isinstance(dialect, Dialect): 715 return dialect 716 if isinstance(dialect, str): 717 try: 718 dialect_name, *kv_strings = dialect.split(",") 719 kv_pairs = (kv.split("=") for kv in kv_strings) 720 kwargs = {} 721 for pair in kv_pairs: 722 key = pair[0].strip() 723 value: t.Union[bool | str | None] = None 724 725 if len(pair) == 1: 726 # Default initialize standalone settings to True 727 value = True 728 elif len(pair) == 2: 729 value = pair[1].strip() 730 731 # Coerce the value to boolean if it matches to the truthy/falsy values below 732 value_lower = value.lower() 733 if value_lower in ("true", "1"): 734 value = True 735 elif value_lower in ("false", "0"): 736 value = False 737 738 kwargs[key] = value 739 740 except ValueError: 741 raise ValueError( 742 f"Invalid dialect format: '{dialect}'. " 743 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 744 ) 745 746 result = cls.get(dialect_name.strip()) 747 if not result: 748 from difflib import get_close_matches 749 750 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 751 if similar: 752 similar = f" Did you mean {similar}?" 753 754 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 755 756 return result(**kwargs) 757 758 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
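A sketch of the settings-string form documented above; extra key-value pairs land in `Dialect.settings`, while `normalization_strategy` is promoted to an attribute by `__init__` (the `some_flag` key is hypothetical, used only to show the standalone-setting default):

>>> from sqlglot.dialects.dialect import Dialect, NormalizationStrategy
>>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
>>> d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
True
>>> Dialect.get_or_raise("duckdb, some_flag").settings
{'some_flag': True}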
760 @classmethod 761 def format_time( 762 cls, expression: t.Optional[str | exp.Expression] 763 ) -> t.Optional[exp.Expression]: 764 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 765 if isinstance(expression, str): 766 return exp.Literal.string( 767 # the time formats are quoted 768 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 769 ) 770 771 if expression and expression.is_string: 772 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 773 774 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
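A sketch using the Hive dialect, whose `TIME_MAPPING` translates Java-style tokens to `strftime` equivalents (the specific tokens used here are an assumption about that mapping); note the surrounding quotes, since string inputs are expected to be quoted literals:

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect["hive"].format_time("'yyyy-MM-dd'").sql()
"'%Y-%m-%d'"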
794 def normalize_identifier(self, expression: E) -> E: 795 """ 796 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 797 798 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 799 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 800 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 801 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 802 803 There are also dialects like Spark, which are case-insensitive even when quotes are 804 present, and dialects like MySQL, whose resolution rules match those employed by the 805 underlying operating system, for example they may always be case-sensitive in Linux. 806 807 Finally, the normalization behavior of some engines can even be controlled through flags, 808 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 809 810 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 811 that it can analyze queries in the optimizer and successfully capture their semantics. 812 """ 813 if ( 814 isinstance(expression, exp.Identifier) 815 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 816 and ( 817 not expression.quoted 818 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 819 ) 820 ): 821 expression.set( 822 "this", 823 ( 824 expression.this.upper() 825 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 826 else expression.this.lower() 827 ), 828 ) 829 830 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
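A minimal sketch of the behaviors described above, using the Postgres and Snowflake examples from the docstring:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).this
'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).this
'FOO'
>>> # quoted identifiers are left alone under case-sensitive-when-quoted rules
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO", quoted=True)).this
'FoO'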
832 def case_sensitive(self, text: str) -> bool: 833 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 834 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 835 return False 836 837 unsafe = ( 838 str.islower 839 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 840 else str.isupper 841 ) 842 return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
844 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 845 """Checks if text can be identified given an identify option. 846 847 Args: 848 text: The text to check. 849 identify: 850 `"always"` or `True`: Always returns `True`. 851 `"safe"`: Only returns `True` if the identifier is case-insensitive. 852 853 Returns: 854 Whether the given text can be identified. 855 """ 856 if identify is True or identify == "always": 857 return True 858 859 if identify == "safe": 860 return not self.case_sensitive(text) 861 862 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: Always returns `True`. `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
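A short sketch under the default lowercase strategy, where uppercase characters are the "unsafe" ones:

>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.can_identify("foo", "safe")
True
>>> d.can_identify("Foo", "safe")
False
>>> d.can_identify("Foo", "always")
True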
864 def quote_identifier(self, expression: E, identify: bool = True) -> E: 865 """ 866 Adds quotes to a given identifier. 867 868 Args: 869 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 870 identify: If set to `False`, the quotes will only be added if the identifier is deemed 871 "unsafe", with respect to its characters and this dialect's normalization strategy. 872 """ 873 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 874 name = expression.this 875 expression.set( 876 "quoted", 877 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 878 ) 879 880 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
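A sketch of the `identify=False` path, which quotes only "unsafe" names (under the default lowercase strategy, `Foo` is unsafe while `foo` is not):

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect()
>>> d.quote_identifier(exp.to_identifier("foo"), identify=False).sql()
'foo'
>>> d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql()
'"Foo"'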
882 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 883 if isinstance(path, exp.Literal): 884 path_text = path.name 885 if path.is_number: 886 path_text = f"[{path_text}]" 887 try: 888 return parse_json_path(path_text, self) 889 except ParseError as e: 890 logger.warning(f"Invalid JSON path syntax. {str(e)}") 891 892 return path
942def if_sql( 943 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 944) -> t.Callable[[Generator, exp.If], str]: 945 def _if_sql(self: Generator, expression: exp.If) -> str: 946 return self.func( 947 name, 948 expression.this, 949 expression.args.get("true"), 950 expression.args.get("false") or false_value, 951 ) 952 953 return _if_sql
956def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 957 this = expression.this 958 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 959 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 960 961 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1027def str_position_sql( 1028 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1029) -> str: 1030 this = self.sql(expression, "this") 1031 substr = self.sql(expression, "substr") 1032 position = self.sql(expression, "position") 1033 instance = expression.args.get("instance") if generate_instance else None 1034 position_offset = "" 1035 1036 if position: 1037 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1038 this = self.func("SUBSTR", this, position) 1039 position_offset = f" + {position} - 1" 1040 1041 return self.func("STRPOS", this, substr, instance) + position_offset
1050def var_map_sql( 1051 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1052) -> str: 1053 keys = expression.args["keys"] 1054 values = expression.args["values"] 1055 1056 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1057 self.unsupported("Cannot convert array columns into map.") 1058 return self.func(map_func_name, keys, values) 1059 1060 args = [] 1061 for key, value in zip(keys.expressions, values.expressions): 1062 args.append(self.sql(key)) 1063 args.append(self.sql(value)) 1064 1065 return self.func(map_func_name, *args)
1068def build_formatted_time( 1069 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1070) -> t.Callable[[t.List], E]: 1071 """Helper used for time expressions. 1072 1073 Args: 1074 exp_class: the expression class to instantiate. 1075 dialect: target sql dialect. 1076 default: the default format, True being time. 1077 1078 Returns: 1079 A callable that can be used to return the appropriately formatted time expression. 1080 """ 1081 1082 def _builder(args: t.List): 1083 return exp_class( 1084 this=seq_get(args, 0), 1085 format=Dialect[dialect].format_time( 1086 seq_get(args, 1) 1087 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1088 ), 1089 ) 1090 1091 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; `True` means fall back to the dialect's `TIME_FORMAT`.
Returns:
A callable that can be used to return the appropriately formatted time expression.
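A sketch of how a dialect module might use this builder to parse a formatted-time function; the Hive mapping and the `exp.StrToTime` target are illustrative assumptions:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import build_formatted_time
>>> builder = build_formatted_time(exp.StrToTime, "hive")
>>> node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
>>> node.args["format"].sql()
"'%Y-%m-%d'"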
1094def time_format( 1095 dialect: DialectType = None, 1096) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1097 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1098 """ 1099 Returns the time format for a given expression, unless it's equivalent 1100 to the default time format of the dialect of interest. 1101 """ 1102 time_format = self.format_time(expression) 1103 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1104 1105 return _time_format
1108def build_date_delta( 1109 exp_class: t.Type[E], 1110 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1111 default_unit: t.Optional[str] = "DAY", 1112) -> t.Callable[[t.List], E]: 1113 def _builder(args: t.List) -> E: 1114 unit_based = len(args) == 3 1115 this = args[2] if unit_based else seq_get(args, 0) 1116 unit = None 1117 if unit_based or default_unit: 1118 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1119 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1120 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1121 1122 return _builder
1125def build_date_delta_with_interval( 1126 expression_class: t.Type[E], 1127) -> t.Callable[[t.List], t.Optional[E]]: 1128 def _builder(args: t.List) -> t.Optional[E]: 1129 if len(args) < 2: 1130 return None 1131 1132 interval = args[1] 1133 1134 if not isinstance(interval, exp.Interval): 1135 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1136 1137 expression = interval.this 1138 if expression and expression.is_string: 1139 expression = exp.Literal.number(expression.this) 1140 1141 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 1142 1143 return _builder
1146def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1147 unit = seq_get(args, 0) 1148 this = seq_get(args, 1) 1149 1150 if isinstance(this, exp.Cast) and this.is_type("date"): 1151 return exp.DateTrunc(unit=unit, this=this) 1152 return exp.TimestampTrunc(this=this, unit=unit)
1155def date_add_interval_sql( 1156 data_type: str, kind: str 1157) -> t.Callable[[Generator, exp.Expression], str]: 1158 def func(self: Generator, expression: exp.Expression) -> str: 1159 this = self.sql(expression, "this") 1160 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1161 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1162 1163 return func
1166def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1167 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1168 args = [unit_to_str(expression), expression.this] 1169 if zone: 1170 args.append(expression.args.get("zone")) 1171 return self.func("DATE_TRUNC", *args) 1172 1173 return _timestamptrunc_sql
1176def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1177 zone = expression.args.get("zone") 1178 if not zone: 1179 from sqlglot.optimizer.annotate_types import annotate_types 1180 1181 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1182 return self.sql(exp.cast(expression.this, target_type)) 1183 if zone.name.lower() in TIMEZONES: 1184 return self.sql( 1185 exp.AtTimeZone( 1186 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1187 zone=zone, 1188 ) 1189 ) 1190 return self.func("TIMESTAMP", expression.this, zone)
1193def no_time_sql(self: Generator, expression: exp.Time) -> str: 1194 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1195 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1196 expr = exp.cast( 1197 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1198 ) 1199 return self.sql(expr)
1202def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1203 this = expression.this 1204 expr = expression.expression 1205 1206 if expr.name.lower() in TIMEZONES: 1207 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1208 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1209 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1210 return self.sql(this) 1211 1212 this = exp.cast(this, exp.DataType.Type.DATE) 1213 expr = exp.cast(expr, exp.DataType.Type.TIME) 1214 1215 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1256def encode_decode_sql( 1257 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1258) -> str: 1259 charset = expression.args.get("charset") 1260 if charset and charset.name.lower() != "utf-8": 1261 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1262 1263 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1276def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1277 cond = expression.this 1278 1279 if isinstance(expression.this, exp.Distinct): 1280 cond = expression.this.expressions[0] 1281 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1282 1283 return self.func("sum", exp.func("if", cond, 1, 0))
1286def trim_sql(self: Generator, expression: exp.Trim) -> str: 1287 target = self.sql(expression, "this") 1288 trim_type = self.sql(expression, "position") 1289 remove_chars = self.sql(expression, "expression") 1290 collation = self.sql(expression, "collation") 1291 1292 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1293 if not remove_chars and not collation: 1294 return self.trim_sql(expression) 1295 1296 trim_type = f"{trim_type} " if trim_type else "" 1297 remove_chars = f"{remove_chars} " if remove_chars else "" 1298 from_part = "FROM " if trim_type or remove_chars else "" 1299 collation = f" COLLATE {collation}" if collation else "" 1300 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1321def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1322 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1323 if bad_args: 1324 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1325 1326 return self.func( 1327 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1328 )
1331def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1332 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1333 if bad_args: 1334 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1335 1336 return self.func( 1337 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1338 )
1341def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1342 names = [] 1343 for agg in aggregations: 1344 if isinstance(agg, exp.Alias): 1345 names.append(agg.alias) 1346 else: 1347 """ 1348 This case corresponds to aggregations without aliases being used as suffixes 1349 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1350 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1351 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1352 """ 1353 agg_all_unquoted = agg.transform( 1354 lambda node: ( 1355 exp.Identifier(this=node.name, quoted=False) 1356 if isinstance(node, exp.Identifier) 1357 else node 1358 ) 1359 ) 1360 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1361 1362 return names
1402def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1403 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1404 if expression.args.get("count"): 1405 self.unsupported(f"Only two arguments are supported in function {name}.") 1406 1407 return self.func(name, expression.this, expression.expression) 1408 1409 return _arg_max_or_min_sql
1412def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1413 this = expression.this.copy() 1414 1415 return_type = expression.return_type 1416 if return_type.is_type(exp.DataType.Type.DATE): 1417 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1418 # can truncate timestamp strings, because some dialects can't cast them to DATE 1419 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1420 1421 expression.this.replace(exp.cast(this, return_type)) 1422 return expression
1425def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1426 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1427 if cast and isinstance(expression, exp.TsOrDsAdd): 1428 expression = ts_or_ds_add_cast(expression) 1429 1430 return self.func( 1431 name, 1432 unit_to_var(expression), 1433 expression.expression, 1434 expression.this, 1435 ) 1436 1437 return _delta_sql
1440def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1441 unit = expression.args.get("unit") 1442 1443 if isinstance(unit, exp.Placeholder): 1444 return unit 1445 if unit: 1446 return exp.Literal.string(unit.name) 1447 return exp.Literal.string(default) if default else None
1477def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1478 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1479 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1480 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1481 1482 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1485def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1486 """Remove table refs from columns in when statements.""" 1487 alias = expression.this.args.get("alias") 1488 1489 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1490 return self.dialect.normalize_identifier(identifier).name if identifier else None 1491 1492 targets = {normalize(expression.this.this)} 1493 1494 if alias: 1495 targets.add(normalize(alias.this)) 1496 1497 for when in expression.expressions: 1498 when.transform( 1499 lambda node: ( 1500 exp.column(node.this) 1501 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1502 else node 1503 ), 1504 copy=False, 1505 ) 1506 1507 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1510def build_json_extract_path( 1511 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1512) -> t.Callable[[t.List], F]: 1513 def _builder(args: t.List) -> F: 1514 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1515 for arg in args[1:]: 1516 if not isinstance(arg, exp.Literal): 1517 # We use the fallback parser because we can't really transpile non-literals safely 1518 return expr_type.from_arg_list(args) 1519 1520 text = arg.name 1521 if is_int(text): 1522 index = int(text) 1523 segments.append( 1524 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1525 ) 1526 else: 1527 segments.append(exp.JSONPathKey(this=text)) 1528 1529 # This is done to avoid failing in the expression validator due to the arg count 1530 del args[2:] 1531 return expr_type( 1532 this=seq_get(args, 0), 1533 expression=exp.JSONPath(expressions=segments), 1534 only_json_types=arrow_req_json_type, 1535 ) 1536 1537 return _builder
1540def json_extract_segments( 1541 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1542) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1543 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1544 path = expression.expression 1545 if not isinstance(path, exp.JSONPath): 1546 return rename_func(name)(self, expression) 1547 1548 segments = [] 1549 for segment in path.expressions: 1550 path = self.sql(segment) 1551 if path: 1552 if isinstance(segment, exp.JSONPathPart) and ( 1553 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1554 ): 1555 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1556 1557 segments.append(path) 1558 1559 if op: 1560 return f" {op} ".join([self.sql(expression.this), *segments]) 1561 return self.func(name, expression.this, *segments) 1562 1563 return _json_extract_segments
1573def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1574 cond = expression.expression 1575 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1576 alias = cond.expressions[0] 1577 cond = cond.this 1578 elif isinstance(cond, exp.Predicate): 1579 alias = "_u" 1580 else: 1581 self.unsupported("Unsupported filter condition") 1582 return "" 1583 1584 unnest = exp.Unnest(expressions=[expression.this]) 1585 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1586 return self.sql(exp.Array(expressions=[filtered]))
1598def build_default_decimal_type( 1599 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1600) -> t.Callable[[exp.DataType], exp.DataType]: 1601 def _builder(dtype: exp.DataType) -> exp.DataType: 1602 if dtype.expressions or precision is None: 1603 return dtype 1604 1605 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1606 return exp.DataType.build(f"DECIMAL({params})") 1607 1608 return _builder
1611def build_timestamp_from_parts(args: t.List) -> exp.Func: 1612 if len(args) == 2: 1613 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1614 # so we parse this into Anonymous for now instead of introducing complexity 1615 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1616 1617 return exp.TimestampFromParts.from_arg_list(args)
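A quick sketch of the two-argument fallback described in the comment above:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import build_timestamp_from_parts
>>> # two args: no cross-dialect equivalent, so it stays an anonymous function
>>> type(build_timestamp_from_parts([exp.column("d"), exp.column("t")])).__name__
'Anonymous'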
1624def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1625 start = expression.args.get("start") 1626 end = expression.args.get("end") 1627 step = expression.args.get("step") 1628 1629 if isinstance(start, exp.Cast): 1630 target_type = start.to 1631 elif isinstance(end, exp.Cast): 1632 target_type = end.to 1633 else: 1634 target_type = None 1635 1636 if start and end and target_type and target_type.is_type("date", "timestamp"): 1637 if isinstance(start, exp.Cast) and target_type is start.to: 1638 end = exp.cast(end, target_type) 1639 else: 1640 start = exp.cast(start, target_type) 1641 1642 return self.func("SEQUENCE", start, end, step)