sqlglot.dialects.dialect
1from __future__ import annotations 2 3import logging 4import typing as t 5from enum import Enum, auto 6from functools import reduce 7 8from sqlglot import exp 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator 11from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses 12from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path 13from sqlglot.parser import Parser 14from sqlglot.time import TIMEZONES, format_time, subsecond_precision 15from sqlglot.tokens import Token, Tokenizer, TokenType 16from sqlglot.trie import new_trie 17 18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] 21 22 23if t.TYPE_CHECKING: 24 from sqlglot._typing import B, E, F 25 26 from sqlglot.optimizer.annotate_types import TypeAnnotator 27 28 AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]] 29 30logger = logging.getLogger("sqlglot") 31 32UNESCAPED_SEQUENCES = { 33 "\\a": "\a", 34 "\\b": "\b", 35 "\\f": "\f", 36 "\\n": "\n", 37 "\\r": "\r", 38 "\\t": "\t", 39 "\\v": "\v", 40 "\\\\": "\\", 41} 42 43 44def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: 45 return lambda self, e: self._annotate_with_type(e, data_type) 46 47 48class Dialects(str, Enum): 49 """Dialects supported by SQLGLot.""" 50 51 DIALECT = "" 52 53 ATHENA = "athena" 54 BIGQUERY = "bigquery" 55 CLICKHOUSE = "clickhouse" 56 DATABRICKS = "databricks" 57 DORIS = "doris" 58 DRILL = "drill" 59 DUCKDB = "duckdb" 60 HIVE = "hive" 61 MATERIALIZE = "materialize" 62 MYSQL = "mysql" 63 ORACLE = "oracle" 64 POSTGRES = "postgres" 65 PRESTO = "presto" 66 PRQL = "prql" 67 REDSHIFT = "redshift" 68 RISINGWAVE = "risingwave" 69 SNOWFLAKE = "snowflake" 70 SPARK = "spark" 71 SPARK2 = "spark2" 72 SQLITE = "sqlite" 73 STARROCKS = "starrocks" 74 TABLEAU = 
"tableau" 75 TERADATA = "teradata" 76 TRINO = "trino" 77 TSQL = "tsql" 78 79 80class NormalizationStrategy(str, AutoName): 81 """Specifies the strategy according to which identifiers should be normalized.""" 82 83 LOWERCASE = auto() 84 """Unquoted identifiers are lowercased.""" 85 86 UPPERCASE = auto() 87 """Unquoted identifiers are uppercased.""" 88 89 CASE_SENSITIVE = auto() 90 """Always case-sensitive, regardless of quotes.""" 91 92 CASE_INSENSITIVE = auto() 93 """Always case-insensitive, regardless of quotes.""" 94 95 96class _Dialect(type): 97 classes: t.Dict[str, t.Type[Dialect]] = {} 98 99 def __eq__(cls, other: t.Any) -> bool: 100 if cls is other: 101 return True 102 if isinstance(other, str): 103 return cls is cls.get(other) 104 if isinstance(other, Dialect): 105 return cls is type(other) 106 107 return False 108 109 def __hash__(cls) -> int: 110 return hash(cls.__name__.lower()) 111 112 @classmethod 113 def __getitem__(cls, key: str) -> t.Type[Dialect]: 114 return cls.classes[key] 115 116 @classmethod 117 def get( 118 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 119 ) -> t.Optional[t.Type[Dialect]]: 120 return cls.classes.get(key, default) 121 122 def __new__(cls, clsname, bases, attrs): 123 klass = super().__new__(cls, clsname, bases, attrs) 124 enum = Dialects.__members__.get(clsname.upper()) 125 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 126 127 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 128 klass.FORMAT_TRIE = ( 129 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 130 ) 131 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 132 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 133 klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()} 134 klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING) 135 136 klass.INVERSE_CREATABLE_KIND_MAPPING = { 137 v: k for k, v in klass.CREATABLE_KIND_MAPPING.items() 138 } 
139 140 base = seq_get(bases, 0) 141 base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),) 142 base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),) 143 base_parser = (getattr(base, "parser_class", Parser),) 144 base_generator = (getattr(base, "generator_class", Generator),) 145 146 klass.tokenizer_class = klass.__dict__.get( 147 "Tokenizer", type("Tokenizer", base_tokenizer, {}) 148 ) 149 klass.jsonpath_tokenizer_class = klass.__dict__.get( 150 "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {}) 151 ) 152 klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {})) 153 klass.generator_class = klass.__dict__.get( 154 "Generator", type("Generator", base_generator, {}) 155 ) 156 157 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 158 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 159 klass.tokenizer_class._IDENTIFIERS.items() 160 )[0] 161 162 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 163 return next( 164 ( 165 (s, e) 166 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 167 if t == token_type 168 ), 169 (None, None), 170 ) 171 172 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 173 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 174 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 175 klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 176 177 if "\\" in klass.tokenizer_class.STRING_ESCAPES: 178 klass.UNESCAPED_SEQUENCES = { 179 **UNESCAPED_SEQUENCES, 180 **klass.UNESCAPED_SEQUENCES, 181 } 182 183 klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()} 184 185 klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS 186 187 if enum not in ("", "bigquery"): 188 klass.generator_class.SELECT_KINDS = () 189 190 if enum not in ("", "clickhouse"): 191 
klass.generator_class.SUPPORTS_NULLABLE_TYPES = False 192 193 if enum not in ("", "athena", "presto", "trino"): 194 klass.generator_class.TRY_SUPPORTED = False 195 klass.generator_class.SUPPORTS_UESCAPE = False 196 197 if enum not in ("", "databricks", "hive", "spark", "spark2"): 198 modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() 199 for modifier in ("cluster", "distribute", "sort"): 200 modifier_transforms.pop(modifier, None) 201 202 klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms 203 204 if enum not in ("", "doris", "mysql"): 205 klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { 206 TokenType.STRAIGHT_JOIN, 207 } 208 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 209 TokenType.STRAIGHT_JOIN, 210 } 211 212 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 213 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 214 TokenType.ANTI, 215 TokenType.SEMI, 216 } 217 218 return klass 219 220 221class Dialect(metaclass=_Dialect): 222 INDEX_OFFSET = 0 223 """The base index offset for arrays.""" 224 225 WEEK_OFFSET = 0 226 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). 
-1 would be Sunday.""" 227 228 UNNEST_COLUMN_ONLY = False 229 """Whether `UNNEST` table aliases are treated as column aliases.""" 230 231 ALIAS_POST_TABLESAMPLE = False 232 """Whether the table alias comes after tablesample.""" 233 234 TABLESAMPLE_SIZE_IS_PERCENT = False 235 """Whether a size in the table sample clause represents percentage.""" 236 237 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 238 """Specifies the strategy according to which identifiers should be normalized.""" 239 240 IDENTIFIERS_CAN_START_WITH_DIGIT = False 241 """Whether an unquoted identifier can start with a digit.""" 242 243 DPIPE_IS_STRING_CONCAT = True 244 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 245 246 STRICT_STRING_CONCAT = False 247 """Whether `CONCAT`'s arguments must be strings.""" 248 249 SUPPORTS_USER_DEFINED_TYPES = True 250 """Whether user-defined data types are supported.""" 251 252 SUPPORTS_SEMI_ANTI_JOIN = True 253 """Whether `SEMI` or `ANTI` joins are supported.""" 254 255 SUPPORTS_COLUMN_JOIN_MARKS = False 256 """Whether the old-style outer join (+) syntax is supported.""" 257 258 COPY_PARAMS_ARE_CSV = True 259 """Separator of COPY statement parameters.""" 260 261 NORMALIZE_FUNCTIONS: bool | str = "upper" 262 """ 263 Determines how function names are going to be normalized. 264 Possible values: 265 "upper" or True: Convert names to uppercase. 266 "lower": Convert names to lowercase. 267 False: Disables function name normalization. 268 """ 269 270 LOG_BASE_FIRST: t.Optional[bool] = True 271 """ 272 Whether the base comes first in the `LOG` function. 273 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 274 """ 275 276 NULL_ORDERING = "nulls_are_small" 277 """ 278 Default `NULL` ordering method to use if not explicitly set. 
279 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 280 """ 281 282 TYPED_DIVISION = False 283 """ 284 Whether the behavior of `a / b` depends on the types of `a` and `b`. 285 False means `a / b` is always float division. 286 True means `a / b` is integer division if both `a` and `b` are integers. 287 """ 288 289 SAFE_DIVISION = False 290 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 291 292 CONCAT_COALESCE = False 293 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 294 295 HEX_LOWERCASE = False 296 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 297 298 DATE_FORMAT = "'%Y-%m-%d'" 299 DATEINT_FORMAT = "'%Y%m%d'" 300 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 301 302 TIME_MAPPING: t.Dict[str, str] = {} 303 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 304 305 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 306 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 307 FORMAT_MAPPING: t.Dict[str, str] = {} 308 """ 309 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 310 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 311 """ 312 313 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 314 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 315 316 PSEUDOCOLUMNS: t.Set[str] = set() 317 """ 318 Columns that are auto-generated by the engine corresponding to this dialect. 319 For example, such columns may be excluded from `SELECT *` queries. 
320 """ 321 322 PREFER_CTE_ALIAS_COLUMN = False 323 """ 324 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 325 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 326 any projection aliases in the subquery. 327 328 For example, 329 WITH y(c) AS ( 330 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 331 ) SELECT c FROM y; 332 333 will be rewritten as 334 335 WITH y(c) AS ( 336 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 337 ) SELECT c FROM y; 338 """ 339 340 COPY_PARAMS_ARE_CSV = True 341 """ 342 Whether COPY statement parameters are separated by comma or whitespace 343 """ 344 345 FORCE_EARLY_ALIAS_REF_EXPANSION = False 346 """ 347 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 348 349 For example: 350 WITH data AS ( 351 SELECT 352 1 AS id, 353 2 AS my_id 354 ) 355 SELECT 356 id AS my_id 357 FROM 358 data 359 WHERE 360 my_id = 1 361 GROUP BY 362 my_id, 363 HAVING 364 my_id = 1 365 366 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 367 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 368 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 369 - Clickhouse, which will forward the alias across the query i.e it resolves 370 to "WHERE id = 1 GROUP BY id HAVING id = 1" 371 """ 372 373 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 374 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 375 376 SUPPORTS_ORDER_BY_ALL = False 377 """ 378 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 379 """ 380 381 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 382 """ 383 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 384 as the former is of type INT[] vs the latter which is SUPER 385 """ 386 387 
SUPPORTS_FIXED_SIZE_ARRAYS = False 388 """ 389 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 390 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 391 be interpreted as a subscript/index operator. 392 """ 393 394 STRICT_JSON_PATH_SYNTAX = True 395 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 396 397 ON_CONDITION_EMPTY_BEFORE_ERROR = True 398 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 399 400 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 401 """Whether ArrayAgg needs to filter NULL values.""" 402 403 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 404 exp.Except: True, 405 exp.Intersect: True, 406 exp.Union: True, 407 } 408 """ 409 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 410 must be explicitly specified. 411 """ 412 413 CREATABLE_KIND_MAPPING: dict[str, str] = {} 414 """ 415 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 416 equivalent of CREATE SCHEMA is CREATE DATABASE. 
417 """ 418 419 # --- Autofilled --- 420 421 tokenizer_class = Tokenizer 422 jsonpath_tokenizer_class = JSONPathTokenizer 423 parser_class = Parser 424 generator_class = Generator 425 426 # A trie of the time_mapping keys 427 TIME_TRIE: t.Dict = {} 428 FORMAT_TRIE: t.Dict = {} 429 430 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 431 INVERSE_TIME_TRIE: t.Dict = {} 432 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 433 INVERSE_FORMAT_TRIE: t.Dict = {} 434 435 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 436 437 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 438 439 # Delimiters for string literals and identifiers 440 QUOTE_START = "'" 441 QUOTE_END = "'" 442 IDENTIFIER_START = '"' 443 IDENTIFIER_END = '"' 444 445 # Delimiters for bit, hex, byte and unicode literals 446 BIT_START: t.Optional[str] = None 447 BIT_END: t.Optional[str] = None 448 HEX_START: t.Optional[str] = None 449 HEX_END: t.Optional[str] = None 450 BYTE_START: t.Optional[str] = None 451 BYTE_END: t.Optional[str] = None 452 UNICODE_START: t.Optional[str] = None 453 UNICODE_END: t.Optional[str] = None 454 455 DATE_PART_MAPPING = { 456 "Y": "YEAR", 457 "YY": "YEAR", 458 "YYY": "YEAR", 459 "YYYY": "YEAR", 460 "YR": "YEAR", 461 "YEARS": "YEAR", 462 "YRS": "YEAR", 463 "MM": "MONTH", 464 "MON": "MONTH", 465 "MONS": "MONTH", 466 "MONTHS": "MONTH", 467 "D": "DAY", 468 "DD": "DAY", 469 "DAYS": "DAY", 470 "DAYOFMONTH": "DAY", 471 "DAY OF WEEK": "DAYOFWEEK", 472 "WEEKDAY": "DAYOFWEEK", 473 "DOW": "DAYOFWEEK", 474 "DW": "DAYOFWEEK", 475 "WEEKDAY_ISO": "DAYOFWEEKISO", 476 "DOW_ISO": "DAYOFWEEKISO", 477 "DW_ISO": "DAYOFWEEKISO", 478 "DAY OF YEAR": "DAYOFYEAR", 479 "DOY": "DAYOFYEAR", 480 "DY": "DAYOFYEAR", 481 "W": "WEEK", 482 "WK": "WEEK", 483 "WEEKOFYEAR": "WEEK", 484 "WOY": "WEEK", 485 "WY": "WEEK", 486 "WEEK_ISO": "WEEKISO", 487 "WEEKOFYEARISO": "WEEKISO", 488 "WEEKOFYEAR_ISO": "WEEKISO", 489 "Q": "QUARTER", 490 "QTR": "QUARTER", 491 "QTRS": "QUARTER", 492 "QUARTERS": "QUARTER", 493 "H": "HOUR", 494 "HH": "HOUR", 
495 "HR": "HOUR", 496 "HOURS": "HOUR", 497 "HRS": "HOUR", 498 "M": "MINUTE", 499 "MI": "MINUTE", 500 "MIN": "MINUTE", 501 "MINUTES": "MINUTE", 502 "MINS": "MINUTE", 503 "S": "SECOND", 504 "SEC": "SECOND", 505 "SECONDS": "SECOND", 506 "SECS": "SECOND", 507 "MS": "MILLISECOND", 508 "MSEC": "MILLISECOND", 509 "MSECS": "MILLISECOND", 510 "MSECOND": "MILLISECOND", 511 "MSECONDS": "MILLISECOND", 512 "MILLISEC": "MILLISECOND", 513 "MILLISECS": "MILLISECOND", 514 "MILLISECON": "MILLISECOND", 515 "MILLISECONDS": "MILLISECOND", 516 "US": "MICROSECOND", 517 "USEC": "MICROSECOND", 518 "USECS": "MICROSECOND", 519 "MICROSEC": "MICROSECOND", 520 "MICROSECS": "MICROSECOND", 521 "USECOND": "MICROSECOND", 522 "USECONDS": "MICROSECOND", 523 "MICROSECONDS": "MICROSECOND", 524 "NS": "NANOSECOND", 525 "NSEC": "NANOSECOND", 526 "NANOSEC": "NANOSECOND", 527 "NSECOND": "NANOSECOND", 528 "NSECONDS": "NANOSECOND", 529 "NANOSECS": "NANOSECOND", 530 "EPOCH_SECOND": "EPOCH", 531 "EPOCH_SECONDS": "EPOCH", 532 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 533 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 534 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 535 "TZH": "TIMEZONE_HOUR", 536 "TZM": "TIMEZONE_MINUTE", 537 "DEC": "DECADE", 538 "DECS": "DECADE", 539 "DECADES": "DECADE", 540 "MIL": "MILLENIUM", 541 "MILS": "MILLENIUM", 542 "MILLENIA": "MILLENIUM", 543 "C": "CENTURY", 544 "CENT": "CENTURY", 545 "CENTS": "CENTURY", 546 "CENTURIES": "CENTURY", 547 } 548 549 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 550 exp.DataType.Type.BIGINT: { 551 exp.ApproxDistinct, 552 exp.ArraySize, 553 exp.Length, 554 }, 555 exp.DataType.Type.BOOLEAN: { 556 exp.Between, 557 exp.Boolean, 558 exp.In, 559 exp.RegexpLike, 560 }, 561 exp.DataType.Type.DATE: { 562 exp.CurrentDate, 563 exp.Date, 564 exp.DateFromParts, 565 exp.DateStrToDate, 566 exp.DiToDate, 567 exp.StrToDate, 568 exp.TimeStrToDate, 569 exp.TsOrDsToDate, 570 }, 571 exp.DataType.Type.DATETIME: { 572 exp.CurrentDatetime, 573 
exp.Datetime, 574 exp.DatetimeAdd, 575 exp.DatetimeSub, 576 }, 577 exp.DataType.Type.DOUBLE: { 578 exp.ApproxQuantile, 579 exp.Avg, 580 exp.Div, 581 exp.Exp, 582 exp.Ln, 583 exp.Log, 584 exp.Pow, 585 exp.Quantile, 586 exp.Round, 587 exp.SafeDivide, 588 exp.Sqrt, 589 exp.Stddev, 590 exp.StddevPop, 591 exp.StddevSamp, 592 exp.Variance, 593 exp.VariancePop, 594 }, 595 exp.DataType.Type.INT: { 596 exp.Ceil, 597 exp.DatetimeDiff, 598 exp.DateDiff, 599 exp.TimestampDiff, 600 exp.TimeDiff, 601 exp.DateToDi, 602 exp.Levenshtein, 603 exp.Sign, 604 exp.StrPosition, 605 exp.TsOrDiToDi, 606 }, 607 exp.DataType.Type.JSON: { 608 exp.ParseJSON, 609 }, 610 exp.DataType.Type.TIME: { 611 exp.Time, 612 }, 613 exp.DataType.Type.TIMESTAMP: { 614 exp.CurrentTime, 615 exp.CurrentTimestamp, 616 exp.StrToTime, 617 exp.TimeAdd, 618 exp.TimeStrToTime, 619 exp.TimeSub, 620 exp.TimestampAdd, 621 exp.TimestampSub, 622 exp.UnixToTime, 623 }, 624 exp.DataType.Type.TINYINT: { 625 exp.Day, 626 exp.Month, 627 exp.Week, 628 exp.Year, 629 exp.Quarter, 630 }, 631 exp.DataType.Type.VARCHAR: { 632 exp.ArrayConcat, 633 exp.Concat, 634 exp.ConcatWs, 635 exp.DateToDateStr, 636 exp.GroupConcat, 637 exp.Initcap, 638 exp.Lower, 639 exp.Substring, 640 exp.TimeToStr, 641 exp.TimeToTimeStr, 642 exp.Trim, 643 exp.TsOrDsToDateStr, 644 exp.UnixToStr, 645 exp.UnixToTimeStr, 646 exp.Upper, 647 }, 648 } 649 650 ANNOTATORS: AnnotatorsType = { 651 **{ 652 expr_type: lambda self, e: self._annotate_unary(e) 653 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 654 }, 655 **{ 656 expr_type: lambda self, e: self._annotate_binary(e) 657 for expr_type in subclasses(exp.__name__, exp.Binary) 658 }, 659 **{ 660 expr_type: _annotate_with_type_lambda(data_type) 661 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 662 for expr_type in expressions 663 }, 664 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 665 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 
666 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 667 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 668 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 669 exp.Bracket: lambda self, e: self._annotate_bracket(e), 670 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 671 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 672 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 673 exp.Count: lambda self, e: self._annotate_with_type( 674 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 675 ), 676 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 677 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 678 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 679 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 680 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 681 exp.Div: lambda self, e: self._annotate_div(e), 682 exp.Dot: lambda self, e: self._annotate_dot(e), 683 exp.Explode: lambda self, e: self._annotate_explode(e), 684 exp.Extract: lambda self, e: self._annotate_extract(e), 685 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 686 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 687 e, exp.DataType.build("ARRAY<DATE>") 688 ), 689 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 690 e, exp.DataType.build("ARRAY<TIMESTAMP>") 691 ), 692 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 693 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 694 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 695 exp.Literal: lambda self, e: self._annotate_literal(e), 696 exp.Map: lambda self, e: self._annotate_map(e), 697 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 698 
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 699 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 700 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 701 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 702 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 703 exp.Struct: lambda self, e: self._annotate_struct(e), 704 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 705 exp.Timestamp: lambda self, e: self._annotate_with_type( 706 e, 707 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 708 ), 709 exp.ToMap: lambda self, e: self._annotate_to_map(e), 710 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 711 exp.Unnest: lambda self, e: self._annotate_unnest(e), 712 exp.VarMap: lambda self, e: self._annotate_map(e), 713 } 714 715 @classmethod 716 def get_or_raise(cls, dialect: DialectType) -> Dialect: 717 """ 718 Look up a dialect in the global dialect registry and return it if it exists. 719 720 Args: 721 dialect: The target dialect. If this is a string, it can be optionally followed by 722 additional key-value pairs that are separated by commas and are used to specify 723 dialect settings, such as whether the dialect's identifiers are case-sensitive. 724 725 Example: 726 >>> dialect = dialect_class = get_or_raise("duckdb") 727 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 728 729 Returns: 730 The corresponding Dialect instance. 
731 """ 732 733 if not dialect: 734 return cls() 735 if isinstance(dialect, _Dialect): 736 return dialect() 737 if isinstance(dialect, Dialect): 738 return dialect 739 if isinstance(dialect, str): 740 try: 741 dialect_name, *kv_strings = dialect.split(",") 742 kv_pairs = (kv.split("=") for kv in kv_strings) 743 kwargs = {} 744 for pair in kv_pairs: 745 key = pair[0].strip() 746 value: t.Union[bool | str | None] = None 747 748 if len(pair) == 1: 749 # Default initialize standalone settings to True 750 value = True 751 elif len(pair) == 2: 752 value = pair[1].strip() 753 754 # Coerce the value to boolean if it matches to the truthy/falsy values below 755 value_lower = value.lower() 756 if value_lower in ("true", "1"): 757 value = True 758 elif value_lower in ("false", "0"): 759 value = False 760 761 kwargs[key] = value 762 763 except ValueError: 764 raise ValueError( 765 f"Invalid dialect format: '{dialect}'. " 766 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 767 ) 768 769 result = cls.get(dialect_name.strip()) 770 if not result: 771 from difflib import get_close_matches 772 773 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 774 if similar: 775 similar = f" Did you mean {similar}?" 
776 777 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 778 779 return result(**kwargs) 780 781 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 782 783 @classmethod 784 def format_time( 785 cls, expression: t.Optional[str | exp.Expression] 786 ) -> t.Optional[exp.Expression]: 787 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 788 if isinstance(expression, str): 789 return exp.Literal.string( 790 # the time formats are quoted 791 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 792 ) 793 794 if expression and expression.is_string: 795 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 796 797 return expression 798 799 def __init__(self, **kwargs) -> None: 800 normalization_strategy = kwargs.pop("normalization_strategy", None) 801 802 if normalization_strategy is None: 803 self.normalization_strategy = self.NORMALIZATION_STRATEGY 804 else: 805 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 806 807 self.settings = kwargs 808 809 def __eq__(self, other: t.Any) -> bool: 810 # Does not currently take dialect state into account 811 return type(self) == other 812 813 def __hash__(self) -> int: 814 # Does not currently take dialect state into account 815 return hash(type(self)) 816 817 def normalize_identifier(self, expression: E) -> E: 818 """ 819 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 820 821 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 822 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 823 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 824 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 
825 826 There are also dialects like Spark, which are case-insensitive even when quotes are 827 present, and dialects like MySQL, whose resolution rules match those employed by the 828 underlying operating system, for example they may always be case-sensitive in Linux. 829 830 Finally, the normalization behavior of some engines can even be controlled through flags, 831 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 832 833 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 834 that it can analyze queries in the optimizer and successfully capture their semantics. 835 """ 836 if ( 837 isinstance(expression, exp.Identifier) 838 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 839 and ( 840 not expression.quoted 841 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 842 ) 843 ): 844 expression.set( 845 "this", 846 ( 847 expression.this.upper() 848 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 849 else expression.this.lower() 850 ), 851 ) 852 853 return expression 854 855 def case_sensitive(self, text: str) -> bool: 856 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 857 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 858 return False 859 860 unsafe = ( 861 str.islower 862 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 863 else str.isupper 864 ) 865 return any(unsafe(char) for char in text) 866 867 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 868 """Checks if text can be identified given an identify option. 869 870 Args: 871 text: The text to check. 872 identify: 873 `"always"` or `True`: Always returns `True`. 874 `"safe"`: Only returns `True` if the identifier is case-insensitive. 875 876 Returns: 877 Whether the given text can be identified. 
878 """ 879 if identify is True or identify == "always": 880 return True 881 882 if identify == "safe": 883 return not self.case_sensitive(text) 884 885 return False 886 887 def quote_identifier(self, expression: E, identify: bool = True) -> E: 888 """ 889 Adds quotes to a given identifier. 890 891 Args: 892 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 893 identify: If set to `False`, the quotes will only be added if the identifier is deemed 894 "unsafe", with respect to its characters and this dialect's normalization strategy. 895 """ 896 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 897 name = expression.this 898 expression.set( 899 "quoted", 900 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 901 ) 902 903 return expression 904 905 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 906 if isinstance(path, exp.Literal): 907 path_text = path.name 908 if path.is_number: 909 path_text = f"[{path_text}]" 910 try: 911 return parse_json_path(path_text, self) 912 except ParseError as e: 913 if self.STRICT_JSON_PATH_SYNTAX: 914 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 915 916 return path 917 918 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 919 return self.parser(**opts).parse(self.tokenize(sql), sql) 920 921 def parse_into( 922 self, expression_type: exp.IntoType, sql: str, **opts 923 ) -> t.List[t.Optional[exp.Expression]]: 924 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 925 926 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 927 return self.generator(**opts).generate(expression, copy=copy) 928 929 def transpile(self, sql: str, **opts) -> t.List[str]: 930 return [ 931 self.generate(expression, copy=False, **opts) if expression else "" 932 for expression in self.parse(sql) 933 ] 934 935 def tokenize(self, sql: str) -> t.List[Token]: 936 return self.tokenizer.tokenize(sql) 937 938 @property 939 def tokenizer(self) -> Tokenizer: 940 return self.tokenizer_class(dialect=self) 941 942 @property 943 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 944 return self.jsonpath_tokenizer_class(dialect=self) 945 946 def parser(self, **opts) -> Parser: 947 return self.parser_class(dialect=self, **opts) 948 949 def generator(self, **opts) -> Generator: 950 return self.generator_class(dialect=self, **opts) 951 952 953DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 954 955 956def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 957 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 958 959 960def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 961 if expression.args.get("accuracy"): 962 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 963 return self.func("APPROX_COUNT_DISTINCT", expression.this) 964 965 966def if_sql( 967 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 968) -> t.Callable[[Generator, exp.If], str]: 969 def _if_sql(self: Generator, expression: exp.If) -> str: 970 return 
self.func( 971 name, 972 expression.this, 973 expression.args.get("true"), 974 expression.args.get("false") or false_value, 975 ) 976 977 return _if_sql 978 979 980def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 981 this = expression.this 982 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 983 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 984 985 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 986 987 988def inline_array_sql(self: Generator, expression: exp.Array) -> str: 989 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 990 991 992def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 993 elem = seq_get(expression.expressions, 0) 994 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 995 return self.func("ARRAY", elem) 996 return inline_array_sql(self, expression) 997 998 999def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1000 return self.like_sql( 1001 exp.Like( 1002 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1003 ) 1004 ) 1005 1006 1007def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1008 zone = self.sql(expression, "this") 1009 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1010 1011 1012def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1013 if expression.args.get("recursive"): 1014 self.unsupported("Recursive CTEs are unsupported") 1015 expression.args["recursive"] = False 1016 return self.with_sql(expression) 1017 1018 1019def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 1020 n = self.sql(expression, "this") 1021 d = self.sql(expression, "expression") 1022 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 1023 1024 1025def no_tablesample_sql(self: Generator, expression: 
exp.TableSample) -> str: 1026 self.unsupported("TABLESAMPLE unsupported") 1027 return self.sql(expression.this) 1028 1029 1030def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1031 self.unsupported("PIVOT unsupported") 1032 return "" 1033 1034 1035def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1036 return self.cast_sql(expression) 1037 1038 1039def no_comment_column_constraint_sql( 1040 self: Generator, expression: exp.CommentColumnConstraint 1041) -> str: 1042 self.unsupported("CommentColumnConstraint unsupported") 1043 return "" 1044 1045 1046def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1047 self.unsupported("MAP_FROM_ENTRIES unsupported") 1048 return "" 1049 1050 1051def property_sql(self: Generator, expression: exp.Property) -> str: 1052 return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" 1053 1054 1055def str_position_sql( 1056 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1057) -> str: 1058 this = self.sql(expression, "this") 1059 substr = self.sql(expression, "substr") 1060 position = self.sql(expression, "position") 1061 instance = expression.args.get("instance") if generate_instance else None 1062 position_offset = "" 1063 1064 if position: 1065 # Normalize third 'pos' argument into 'SUBSTR(..) 
+ offset' across dialects 1066 this = self.func("SUBSTR", this, position) 1067 position_offset = f" + {position} - 1" 1068 1069 return self.func("STRPOS", this, substr, instance) + position_offset 1070 1071 1072def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1073 return ( 1074 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1075 ) 1076 1077 1078def var_map_sql( 1079 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1080) -> str: 1081 keys = expression.args["keys"] 1082 values = expression.args["values"] 1083 1084 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1085 self.unsupported("Cannot convert array columns into map.") 1086 return self.func(map_func_name, keys, values) 1087 1088 args = [] 1089 for key, value in zip(keys.expressions, values.expressions): 1090 args.append(self.sql(key)) 1091 args.append(self.sql(value)) 1092 1093 return self.func(map_func_name, *args) 1094 1095 1096def build_formatted_time( 1097 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1098) -> t.Callable[[t.List], E]: 1099 """Helper used for time expressions. 1100 1101 Args: 1102 exp_class: the expression class to instantiate. 1103 dialect: target sql dialect. 1104 default: the default format, True being time. 1105 1106 Returns: 1107 A callable that can be used to return the appropriately formatted time expression. 
1108 """ 1109 1110 def _builder(args: t.List): 1111 return exp_class( 1112 this=seq_get(args, 0), 1113 format=Dialect[dialect].format_time( 1114 seq_get(args, 1) 1115 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1116 ), 1117 ) 1118 1119 return _builder 1120 1121 1122def time_format( 1123 dialect: DialectType = None, 1124) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1125 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1126 """ 1127 Returns the time format for a given expression, unless it's equivalent 1128 to the default time format of the dialect of interest. 1129 """ 1130 time_format = self.format_time(expression) 1131 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1132 1133 return _time_format 1134 1135 1136def build_date_delta( 1137 exp_class: t.Type[E], 1138 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1139 default_unit: t.Optional[str] = "DAY", 1140) -> t.Callable[[t.List], E]: 1141 def _builder(args: t.List) -> E: 1142 unit_based = len(args) == 3 1143 this = args[2] if unit_based else seq_get(args, 0) 1144 unit = None 1145 if unit_based or default_unit: 1146 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1147 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1148 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1149 1150 return _builder 1151 1152 1153def build_date_delta_with_interval( 1154 expression_class: t.Type[E], 1155) -> t.Callable[[t.List], t.Optional[E]]: 1156 def _builder(args: t.List) -> t.Optional[E]: 1157 if len(args) < 2: 1158 return None 1159 1160 interval = args[1] 1161 1162 if not isinstance(interval, exp.Interval): 1163 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1164 1165 expression = interval.this 1166 if expression and expression.is_string: 1167 expression = 
exp.Literal.number(expression.this) 1168 1169 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 1170 1171 return _builder 1172 1173 1174def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1175 unit = seq_get(args, 0) 1176 this = seq_get(args, 1) 1177 1178 if isinstance(this, exp.Cast) and this.is_type("date"): 1179 return exp.DateTrunc(unit=unit, this=this) 1180 return exp.TimestampTrunc(this=this, unit=unit) 1181 1182 1183def date_add_interval_sql( 1184 data_type: str, kind: str 1185) -> t.Callable[[Generator, exp.Expression], str]: 1186 def func(self: Generator, expression: exp.Expression) -> str: 1187 this = self.sql(expression, "this") 1188 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1189 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1190 1191 return func 1192 1193 1194def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1195 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1196 args = [unit_to_str(expression), expression.this] 1197 if zone: 1198 args.append(expression.args.get("zone")) 1199 return self.func("DATE_TRUNC", *args) 1200 1201 return _timestamptrunc_sql 1202 1203 1204def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1205 zone = expression.args.get("zone") 1206 if not zone: 1207 from sqlglot.optimizer.annotate_types import annotate_types 1208 1209 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1210 return self.sql(exp.cast(expression.this, target_type)) 1211 if zone.name.lower() in TIMEZONES: 1212 return self.sql( 1213 exp.AtTimeZone( 1214 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1215 zone=zone, 1216 ) 1217 ) 1218 return self.func("TIMESTAMP", expression.this, zone) 1219 1220 1221def no_time_sql(self: Generator, expression: exp.Time) -> str: 1222 # Transpile BQ's TIME(timestamp, zone) to 
CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1223 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1224 expr = exp.cast( 1225 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1226 ) 1227 return self.sql(expr) 1228 1229 1230def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1231 this = expression.this 1232 expr = expression.expression 1233 1234 if expr.name.lower() in TIMEZONES: 1235 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1236 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1237 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1238 return self.sql(this) 1239 1240 this = exp.cast(this, exp.DataType.Type.DATE) 1241 expr = exp.cast(expr, exp.DataType.Type.TIME) 1242 1243 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 1244 1245 1246def locate_to_strposition(args: t.List) -> exp.Expression: 1247 return exp.StrPosition( 1248 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 1249 ) 1250 1251 1252def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 1253 return self.func( 1254 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 1255 ) 1256 1257 1258def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1259 return self.sql( 1260 exp.Substring( 1261 this=expression.this, start=exp.Literal.number(1), length=expression.expression 1262 ) 1263 ) 1264 1265 1266def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1267 return self.sql( 1268 exp.Substring( 1269 this=expression.this, 1270 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 1271 ) 1272 ) 1273 1274 1275def timestrtotime_sql( 1276 self: Generator, 1277 expression: exp.TimeStrToTime, 1278 include_precision: bool = False, 1279) -> 
str: 1280 datatype = exp.DataType.build( 1281 exp.DataType.Type.TIMESTAMPTZ 1282 if expression.args.get("zone") 1283 else exp.DataType.Type.TIMESTAMP 1284 ) 1285 1286 if isinstance(expression.this, exp.Literal) and include_precision: 1287 precision = subsecond_precision(expression.this.name) 1288 if precision > 0: 1289 datatype = exp.DataType.build( 1290 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1291 ) 1292 1293 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect)) 1294 1295 1296def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 1297 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 1298 1299 1300# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1301def encode_decode_sql( 1302 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1303) -> str: 1304 charset = expression.args.get("charset") 1305 if charset and charset.name.lower() != "utf-8": 1306 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1307 1308 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1309 1310 1311def min_or_least(self: Generator, expression: exp.Min) -> str: 1312 name = "LEAST" if expression.expressions else "MIN" 1313 return rename_func(name)(self, expression) 1314 1315 1316def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1317 name = "GREATEST" if expression.expressions else "MAX" 1318 return rename_func(name)(self, expression) 1319 1320 1321def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1322 cond = expression.this 1323 1324 if isinstance(expression.this, exp.Distinct): 1325 cond = expression.this.expressions[0] 1326 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1327 1328 return self.func("sum", exp.func("if", cond, 1, 0)) 1329 1330 1331def trim_sql(self: Generator, expression: exp.Trim) -> str: 
1332 target = self.sql(expression, "this") 1333 trim_type = self.sql(expression, "position") 1334 remove_chars = self.sql(expression, "expression") 1335 collation = self.sql(expression, "collation") 1336 1337 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1338 if not remove_chars: 1339 return self.trim_sql(expression) 1340 1341 trim_type = f"{trim_type} " if trim_type else "" 1342 remove_chars = f"{remove_chars} " if remove_chars else "" 1343 from_part = "FROM " if trim_type or remove_chars else "" 1344 collation = f" COLLATE {collation}" if collation else "" 1345 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1346 1347 1348def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1349 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1350 1351 1352def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1353 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1354 1355 1356def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1357 delim, *rest_args = expression.expressions 1358 return self.sql( 1359 reduce( 1360 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1361 rest_args, 1362 ) 1363 ) 1364 1365 1366def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1367 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1368 if bad_args: 1369 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1370 1371 return self.func( 1372 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1373 ) 1374 1375 1376def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1377 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1378 if bad_args: 1379 self.unsupported(f"REGEXP_REPLACE does not support the 
following arg(s): {bad_args}") 1380 1381 return self.func( 1382 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1383 ) 1384 1385 1386def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1387 names = [] 1388 for agg in aggregations: 1389 if isinstance(agg, exp.Alias): 1390 names.append(agg.alias) 1391 else: 1392 """ 1393 This case corresponds to aggregations without aliases being used as suffixes 1394 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1395 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1396 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1397 """ 1398 agg_all_unquoted = agg.transform( 1399 lambda node: ( 1400 exp.Identifier(this=node.name, quoted=False) 1401 if isinstance(node, exp.Identifier) 1402 else node 1403 ) 1404 ) 1405 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1406 1407 return names 1408 1409 1410def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1411 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1412 1413 1414# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1415def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1416 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1417 1418 1419def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1420 return self.func("MAX", expression.this) 1421 1422 1423def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1424 a = self.sql(expression.left) 1425 b = self.sql(expression.right) 1426 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1427 1428 1429def is_parse_json(expression: exp.Expression) -> bool: 1430 return isinstance(expression, exp.ParseJSON) or ( 1431 isinstance(expression, exp.Cast) and expression.is_type("json") 1432 ) 1433 1434 1435def 
isnull_to_is_null(args: t.List) -> exp.Expression: 1436 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1437 1438 1439def generatedasidentitycolumnconstraint_sql( 1440 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1441) -> str: 1442 start = self.sql(expression, "start") or "1" 1443 increment = self.sql(expression, "increment") or "1" 1444 return f"IDENTITY({start}, {increment})" 1445 1446 1447def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1448 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1449 if expression.args.get("count"): 1450 self.unsupported(f"Only two arguments are supported in function {name}.") 1451 1452 return self.func(name, expression.this, expression.expression) 1453 1454 return _arg_max_or_min_sql 1455 1456 1457def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1458 this = expression.this.copy() 1459 1460 return_type = expression.return_type 1461 if return_type.is_type(exp.DataType.Type.DATE): 1462 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1463 # can truncate timestamp strings, because some dialects can't cast them to DATE 1464 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1465 1466 expression.this.replace(exp.cast(this, return_type)) 1467 return expression 1468 1469 1470def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1471 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1472 if cast and isinstance(expression, exp.TsOrDsAdd): 1473 expression = ts_or_ds_add_cast(expression) 1474 1475 return self.func( 1476 name, 1477 unit_to_var(expression), 1478 expression.expression, 1479 expression.this, 1480 ) 1481 1482 return _delta_sql 1483 1484 1485def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1486 unit = expression.args.get("unit") 1487 1488 if 
isinstance(unit, exp.Placeholder): 1489 return unit 1490 if unit: 1491 return exp.Literal.string(unit.name) 1492 return exp.Literal.string(default) if default else None 1493 1494 1495def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1496 unit = expression.args.get("unit") 1497 1498 if isinstance(unit, (exp.Var, exp.Placeholder)): 1499 return unit 1500 return exp.Var(this=default) if default else None 1501 1502 1503@t.overload 1504def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1505 pass 1506 1507 1508@t.overload 1509def map_date_part( 1510 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1511) -> t.Optional[exp.Expression]: 1512 pass 1513 1514 1515def map_date_part(part, dialect: DialectType = Dialect): 1516 mapped = ( 1517 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1518 ) 1519 return exp.var(mapped) if mapped else part 1520 1521 1522def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1523 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1524 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1525 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1526 1527 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1528 1529 1530def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1531 """Remove table refs from columns in when statements.""" 1532 alias = expression.this.args.get("alias") 1533 1534 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1535 return self.dialect.normalize_identifier(identifier).name if identifier else None 1536 1537 targets = {normalize(expression.this.this)} 1538 1539 if alias: 1540 targets.add(normalize(alias.this)) 1541 1542 for when in expression.expressions: 1543 # only remove the target names from the THEN clause 1544 # theyre still valid in the <condition> part of WHEN 
MATCHED / WHEN NOT MATCHED 1545 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1546 then = when.args.get("then") 1547 if then: 1548 then.transform( 1549 lambda node: ( 1550 exp.column(node.this) 1551 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1552 else node 1553 ), 1554 copy=False, 1555 ) 1556 1557 return self.merge_sql(expression) 1558 1559 1560def build_json_extract_path( 1561 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1562) -> t.Callable[[t.List], F]: 1563 def _builder(args: t.List) -> F: 1564 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1565 for arg in args[1:]: 1566 if not isinstance(arg, exp.Literal): 1567 # We use the fallback parser because we can't really transpile non-literals safely 1568 return expr_type.from_arg_list(args) 1569 1570 text = arg.name 1571 if is_int(text): 1572 index = int(text) 1573 segments.append( 1574 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1575 ) 1576 else: 1577 segments.append(exp.JSONPathKey(this=text)) 1578 1579 # This is done to avoid failing in the expression validator due to the arg count 1580 del args[2:] 1581 return expr_type( 1582 this=seq_get(args, 0), 1583 expression=exp.JSONPath(expressions=segments), 1584 only_json_types=arrow_req_json_type, 1585 ) 1586 1587 return _builder 1588 1589 1590def json_extract_segments( 1591 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1592) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1593 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1594 path = expression.expression 1595 if not isinstance(path, exp.JSONPath): 1596 return rename_func(name)(self, expression) 1597 1598 segments = [] 1599 for segment in path.expressions: 1600 path = self.sql(segment) 1601 if path: 1602 if isinstance(segment, exp.JSONPathPart) and ( 1603 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1604 ): 1605 
path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1606 1607 segments.append(path) 1608 1609 if op: 1610 return f" {op} ".join([self.sql(expression.this), *segments]) 1611 return self.func(name, expression.this, *segments) 1612 1613 return _json_extract_segments 1614 1615 1616def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1617 if isinstance(expression.this, exp.JSONPathWildcard): 1618 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1619 1620 return expression.name 1621 1622 1623def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1624 cond = expression.expression 1625 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1626 alias = cond.expressions[0] 1627 cond = cond.this 1628 elif isinstance(cond, exp.Predicate): 1629 alias = "_u" 1630 else: 1631 self.unsupported("Unsupported filter condition") 1632 return "" 1633 1634 unnest = exp.Unnest(expressions=[expression.this]) 1635 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1636 return self.sql(exp.Array(expressions=[filtered])) 1637 1638 1639def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1640 return self.func( 1641 "TO_NUMBER", 1642 expression.this, 1643 expression.args.get("format"), 1644 expression.args.get("nlsparam"), 1645 ) 1646 1647 1648def build_default_decimal_type( 1649 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1650) -> t.Callable[[exp.DataType], exp.DataType]: 1651 def _builder(dtype: exp.DataType) -> exp.DataType: 1652 if dtype.expressions or precision is None: 1653 return dtype 1654 1655 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1656 return exp.DataType.build(f"DECIMAL({params})") 1657 1658 return _builder 1659 1660 1661def build_timestamp_from_parts(args: t.List) -> exp.Func: 1662 if len(args) == 2: 1663 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) 
concept, 1664 # so we parse this into Anonymous for now instead of introducing complexity 1665 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1666 1667 return exp.TimestampFromParts.from_arg_list(args) 1668 1669 1670def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1671 return self.func(f"SHA{expression.text('length') or '256'}", expression.this) 1672 1673 1674def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1675 start = expression.args.get("start") 1676 end = expression.args.get("end") 1677 step = expression.args.get("step") 1678 1679 if isinstance(start, exp.Cast): 1680 target_type = start.to 1681 elif isinstance(end, exp.Cast): 1682 target_type = end.to 1683 else: 1684 target_type = None 1685 1686 if start and end and target_type and target_type.is_type("date", "timestamp"): 1687 if isinstance(start, exp.Cast) and target_type is start.to: 1688 end = exp.cast(end, target_type) 1689 else: 1690 start = exp.cast(start, target_type) 1691 1692 return self.func("SEQUENCE", start, end, step)
class Dialects(str, Enum):
    """Dialects supported by SQLGlot.

    The class mixes in ``str``, so every member is also a plain string and
    compares equal to its lowercase dialect name, e.g.
    ``Dialects.DUCKDB == "duckdb"``. Members can be looked up by value with
    ``Dialects("duckdb")``.
    """

    # Empty-string member, kept apart from the named dialects below.
    # NOTE(review): presumably denotes the base/default dialect -- confirm against callers.
    DIALECT = ""

    # Named dialects, one member per dialect, value == lowercase name.
    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
Dialects supported by SQLGlot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized.

    Each member names one rule for how identifier case is treated; the string
    directly under a member states that rule.
    """

    # Member values come from `auto()`; with the `AutoName` base they are
    # presumably the member names themselves -- NOTE(review): confirm in sqlglot.helper.
    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """
    Groups the settings (class-level flags and mappings) that describe a SQL dialect, together
    with helpers to tokenize, parse and generate SQL for it.
    """

    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    # Default format strings; note that the values include the surrounding single quotes.
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (a literal newline character)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    COPY_PARAMS_ARE_CSV = True
    """
    Whether COPY statement parameters are separated by comma or whitespace
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

    For example:
        WITH data AS (
        SELECT
            1 AS id,
            2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e
          it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query i.e it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3)
    as the former is of type INT[] vs the latter which is SUPER
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g.
    in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should
    be interpreted as a subscript/index operator.
    """

    STRICT_JSON_PATH_SYNTAX = True
    """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning."""

    ON_CONDITION_EMPTY_BEFORE_ERROR = True
    """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle)."""

    ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True
    """Whether ArrayAgg needs to filter NULL values."""

    SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = {
        exp.Except: True,
        exp.Intersect: True,
        exp.Union: True,
    }
    """
    Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL`
    must be explicitly specified.
    """

    CREATABLE_KIND_MAPPING: dict[str, str] = {}
    """
    Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse
    equivalent of CREATE SCHEMA is CREATE DATABASE.
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Maps date part abbreviations / aliases to a single canonical date part name
    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    # Groups expression classes by the data type they are annotated with (see ANNOTATORS below)
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Div,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to the callable used to annotate it. Later entries override
    # earlier ones for the same key (e.g. the explicit exp.Div / exp.ArrayConcat entries
    # replace the ones produced by the comprehensions above them).
    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Count: lambda self, e: self._annotate_with_type(
            e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT
        ),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.

        Raises:
            ValueError: If a setting contains more than one `=`, if the dialect name is not
                registered, or if `dialect` is of an unsupported type.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            dialect_name, *kv_strings = dialect.split(",")
            dialect_name = dialect_name.strip()

            kwargs: t.Dict[str, t.Union[bool, str]] = {}
            for kv in kv_strings:
                key, *values = kv.split("=")

                if len(values) > 1:
                    # More than one "=" is ambiguous; fail loudly instead of silently
                    # storing None for the setting
                    raise ValueError(
                        f"Invalid dialect format: '{dialect}'. "
                        "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                    )

                # Default initialize standalone settings to True
                value: t.Union[bool, str] = True

                if values:
                    value = values[0].strip()

                    # Coerce the value to boolean if it matches to the truthy/falsy values below
                    value_lower = value.lower()
                    if value_lower in ("true", "1"):
                        value = True
                    elif value_lower in ("false", "0"):
                        value = False

                kwargs[key.strip()] = value

            result = cls.get(dialect_name)
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        """
        Args:
            **kwargs: Dialect settings. `normalization_strategy` (a strategy name, matched
                case-insensitively) overrides `NORMALIZATION_STRATEGY`; every other entry is
                stored untouched in `self.settings`.
        """
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        """
        Parses a literal JSON path into a JSON path expression.

        Non-literal inputs are returned unchanged. A numeric literal is wrapped in brackets
        before parsing. If parsing fails, the original `path` is returned, and a warning is
        logged when `STRICT_JSON_PATH_SYNTAX` is set.
        """
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                if self.STRICT_JSON_PATH_SYNTAX:
                    logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it with a parser built from `opts`."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it into the given `expression_type`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Generates SQL for `expression` with a generator built from `opts`."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parses `sql` and generates one string per parsed statement (`""` for empty ones)."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenizes `sql` using this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """A new `tokenizer_class` instance bound to this dialect."""
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        """A new `jsonpath_tokenizer_class` instance bound to this dialect."""
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        """Builds a `parser_class` instance bound to this dialect, forwarding `opts`."""
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        """Builds a `generator_class` instance bound to this dialect, forwarding `opts`."""
        return self.generator_class(dialect=self, **opts)
def __init__(self, **kwargs) -> None:
    """
    Store the dialect settings.

    Args:
        **kwargs: Dialect settings. `normalization_strategy` (a strategy name, matched
            case-insensitively) overrides the class-level `NORMALIZATION_STRATEGY`; every
            other entry is kept as-is in `self.settings`.
    """
    strategy_name = kwargs.pop("normalization_strategy", None)

    self.normalization_strategy = (
        self.NORMALIZATION_STRATEGY
        if strategy_name is None
        else NormalizationStrategy(strategy_name.upper())
    )

    # Whatever remains after popping the strategy is preserved verbatim
    self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG function.
Possible values: True, False, None (two arguments are not supported by LOG)
Default NULL ordering method to use if not explicitly set.
Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"
Whether the behavior of a / b depends on the types of a and b.
False means a / b is always float division.
True means a / b is integer division if both a and b are integers.
A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy').
If empty, the corresponding trie will be constructed off of TIME_MAPPING.
Mapping of an escaped sequence (\\n) to its unescaped version (a literal newline character).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT * queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS ( SELECT 1 AS id, 2 AS my_id ) SELECT id AS my_id FROM data WHERE my_id = 1 GROUP BY my_id HAVING my_id = 1
In most dialects, "my_id" would refer to "data.my_id" across the query, except: - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) as the former is of type INT[] vs the latter which is SUPER
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
Whether a set operation uses DISTINCT by default. This is None when either DISTINCT or ALL
must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
    """
    Look up a dialect in the global dialect registry and return it if it exists.

    Args:
        dialect: The target dialect. If this is a string, it can be optionally followed by
            additional key-value pairs that are separated by commas and are used to specify
            dialect settings, such as whether the dialect's identifiers are case-sensitive.

    Example:
        >>> dialect = dialect_class = Dialect.get_or_raise("duckdb")
        >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

    Returns:
        The corresponding Dialect instance.

    Raises:
        ValueError: If a setting contains more than one `=`, if the dialect name is not
            registered, or if `dialect` is of an unsupported type.
    """

    if not dialect:
        return cls()
    if isinstance(dialect, _Dialect):
        return dialect()
    if isinstance(dialect, Dialect):
        return dialect
    if isinstance(dialect, str):
        dialect_name, *kv_strings = dialect.split(",")
        dialect_name = dialect_name.strip()

        kwargs: t.Dict[str, t.Union[bool, str]] = {}
        for kv in kv_strings:
            key, *values = kv.split("=")

            if len(values) > 1:
                # More than one "=" is ambiguous; fail loudly instead of silently
                # storing None for the setting
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            # Default initialize standalone settings to True
            value: t.Union[bool, str] = True

            if values:
                value = values[0].strip()

                # Coerce the value to boolean if it matches to the truthy/falsy values below
                value_lower = value.lower()
                if value_lower in ("true", "1"):
                    value = True
                elif value_lower in ("false", "0"):
                    value = False

            kwargs[key.strip()] = value

        result = cls.get(dialect_name)
        if not result:
            from difflib import get_close_matches

            similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
            if similar:
                similar = f" Did you mean {similar}?"

            raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

        return result(**kwargs)

    raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """
    Translate a time format written for this dialect into Python `strftime` notation.

    A plain `str` is expected to include its surrounding quote characters, which are
    dropped before conversion. A string-literal expression is converted through its
    `this` value. Anything else (including `None`) is returned unchanged.
    """
    if isinstance(expression, str):
        # the time formats are quoted, so strip the first and last character
        raw_format = expression[1:-1]
    elif expression and expression.is_string:
        raw_format = expression.this
    else:
        return expression

    return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))
Converts a time format in this dialect to its equivalent Python strftime format.
def normalize_identifier(self, expression: E) -> E:
    """
    Rewrite an identifier's name the way this dialect would resolve it.

    For instance, Postgres lowercases every unquoted identifier, so `FoO` resolves to `foo`,
    whereas Snowflake uppercases them and resolves it to `FOO`. A quoted identifier is
    normally case-sensitive, so it is left alone to avoid "breaking" it.

    Some dialects, like Spark, are case-insensitive even when quotes are present. Others,
    like MySQL, follow the rules of the underlying operating system, e.g. they may always be
    case-sensitive in Linux.

    Some engines even expose flags that control this behavior, such as Redshift's
    enable_case_sensitive_identifier.

    SQLGlot aims to handle all of these behaviors gracefully, so that the optimizer can
    analyze queries and capture their semantics.

    The identifier is modified in place and returned; non-identifiers are returned as-is.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    strategy = self.normalization_strategy

    if strategy is NormalizationStrategy.CASE_SENSITIVE:
        return expression

    # Quoted names are only normalized when the dialect ignores case altogether
    if expression.quoted and strategy is not NormalizationStrategy.CASE_INSENSITIVE:
        return expression

    if strategy is NormalizationStrategy.UPPERCASE:
        normalized = expression.this.upper()
    else:
        normalized = expression.this.lower()

    expression.set("this", normalized)
    return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
def case_sensitive(self, text: str) -> bool:
    """
    Tell whether `text` contains characters whose case matters under this dialect's rules.

    Always `False` for the case-insensitive strategy. For the uppercase strategy, any
    lowercase character counts; for every other strategy, any uppercase character counts.
    """
    strategy = self.normalization_strategy

    if strategy is NormalizationStrategy.CASE_INSENSITIVE:
        return False

    if strategy is NormalizationStrategy.UPPERCASE:
        return any(char.islower() for char in text)

    return any(char.isupper() for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
    """Decide whether `text` can be identified under the given `identify` option.

    Args:
        text: The text to check.
        identify:
            `"always"` or `True`: Always returns `True`.
            `"safe"`: Only returns `True` if the identifier is case-insensitive.
            Any other value returns `False`.

    Returns:
        Whether the given text can be identified.
    """
    if identify == "safe":
        return not self.case_sensitive(text)

    return identify is True or identify == "always"
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
def quote_identifier(self, expression: E, identify: bool = True) -> E:
    """
    Set the `quoted` flag of a given identifier.

    Args:
        expression: The expression of interest. If it's not an `Identifier`, or if its parent
            is a `Func`, this method is a no-op.
        identify: If set to `False`, the quotes will only be added if the identifier is deemed
            "unsafe", with respect to its characters and this dialect's normalization strategy.

    Returns:
        The same expression object that was passed in.
    """
    is_quotable = isinstance(expression, exp.Identifier) and not isinstance(
        expression.parent, exp.Func
    )
    if not is_quotable:
        return expression

    name = expression.this
    needs_quotes = (
        identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
    )
    expression.set("quoted", needs_quotes)

    return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an Identifier, this method is a no-op.
- identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
    """Tries to turn a literal into a parsed JSON path.

    Args:
        path: The candidate path. Anything that isn't an `exp.Literal` is returned untouched.

    Returns:
        The result of `parse_json_path` on success. If parsing raises a `ParseError`, the
        original `path` is returned, after logging a warning when `STRICT_JSON_PATH_SYNTAX`
        is set.
    """
    if not isinstance(path, exp.Literal):
        return path

    # Numeric literals are wrapped in brackets before being parsed
    path_text = f"[{path.name}]" if path.is_number else path.name

    try:
        return parse_json_path(path_text, self)
    except ParseError as e:
        if self.STRICT_JSON_PATH_SYNTAX:
            logger.warning(f"Invalid JSON path syntax. {str(e)}")

    return path
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Builds a generator for `exp.If` that emits a function call.

    Args:
        name: The function name to emit.
        false_value: Fallback used when the expression's "false" arg is missing or falsy.

    Returns:
        A callable that renders `name(<condition>, <true>, <false>)` via `self.func`.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        condition = expression.this
        true_branch = expression.args.get("true")
        false_branch = expression.args.get("false") or false_value
        return self.func(name, condition, true_branch, false_branch)

    return _if_sql
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    """Renders a JSON extraction as a binary arrow operator.

    `exp.JSONExtract` uses `->`, any other input uses `->>`. When the generator sets
    `JSON_TYPE_REQUIRED_FOR_EXTRACTION`, a string-literal operand is first replaced in the
    tree by a cast of itself to JSON.
    """
    if isinstance(expression, exp.JSONExtract):
        operator = "->"
    else:
        operator = "->>"

    operand = expression.this
    if (
        self.JSON_TYPE_REQUIRED_FOR_EXTRACTION
        and isinstance(operand, exp.Literal)
        and operand.is_string
    ):
        operand.replace(exp.cast(operand, exp.DataType.Type.JSON))

    return self.binary(expression, operator)
def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    """Renders `exp.StrPosition` as a `STRPOS` call.

    Args:
        expression: The node to render.
        generate_instance: Whether the "instance" arg is forwarded to `STRPOS`.

    Returns:
        `STRPOS(this, substr[, instance])`. If a position is given, the search runs over
        `SUBSTR(this, position)` and ` + <position> - 1` is appended to the call.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None

    if not position:
        return self.func("STRPOS", haystack, needle, instance)

    # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
    shifted = self.func("SUBSTR", haystack, position)
    return self.func("STRPOS", shifted, needle, instance) + f" + {position} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Renders a map as a function call with interleaved `key, value` arguments.

    If either "keys" or "values" is not an `exp.Array`, this is reported as unsupported and
    the two args are passed to the function as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Flatten the (key, value) pairs into a single key1, value1, key2, value2, ... sequence
    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        # An explicit format argument wins; otherwise fall back to the configured default
        time_fmt = seq_get(args, 1)
        if not time_fmt:
            time_fmt = dialect_class.TIME_FORMAT if default is True else default or None

        return exp_class(this=this, format=dialect_class.format_time(time_fmt))

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Builds a helper that resolves an expression's time format against a dialect's default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return fmt

    return _time_format
def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    """Builds a parser callback for date delta functions.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping from lowercased unit names to replacement names.
        default_unit: unit used when the call doesn't have exactly three arguments. If
            falsy, no unit is set in that case.

    Returns:
        A callable that accepts the argument list. With three arguments they are read as
        `(unit, expression, this)`, otherwise as `(this, expression)`.
    """

    def _builder(args: t.List) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        elif default_unit:
            this, unit = seq_get(args, 0), exp.Literal.string(default_unit)
        else:
            return exp_class(this=seq_get(args, 0), expression=seq_get(args, 1), unit=None)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Builds a parser callback for date delta functions whose second argument is an interval.

    Returns:
        A callable that yields `None` for fewer than two arguments and raises `ParseError`
        if the second argument is not an `exp.Interval`. Otherwise it instantiates
        `expression_class` from the first argument, the interval's value and its unit.
    """

    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        # A string interval value is converted into a number literal
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        return expression_class(this=this, expression=amount, unit=unit_to_str(interval))

    return _builder
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parses `(unit, this)` arguments into a truncation expression.

    Returns an `exp.DateTrunc` when `this` is a cast to a date type, and an
    `exp.TimestampTrunc` in every other case.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    casts_to_date = isinstance(this, exp.Cast) and this.is_type("date")
    if not casts_to_date:
        return exp.TimestampTrunc(this=this, unit=unit)

    return exp.DateTrunc(unit=unit, this=this)
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Builds a generator that emits `<data_type>_<kind>(<this>, <interval>)`.

    The interval is an `exp.Interval` built from the expression's "expression" arg and the
    result of `unit_to_var`.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        operand = self.sql(expression, "this")
        delta = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        delta_sql = self.sql(delta)
        return f"{data_type}_{kind}({operand}, {delta_sql})"

    return func
def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    """Builds a generator that renders `exp.TimestampTrunc` as `DATE_TRUNC(unit, this[, zone])`.

    Args:
        zone: whether the expression's "zone" arg is forwarded as a third argument.
    """

    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        unit = unit_to_str(expression)
        if zone:
            return self.func("DATE_TRUNC", unit, expression.this, expression.args.get("zone"))
        return self.func("DATE_TRUNC", unit, expression.this)

    return _timestamptrunc_sql
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Renders `exp.Timestamp` without relying on a zone-less `TIMESTAMP` function.

    - No zone: a cast of `this` to the annotated type of the expression, falling back to
      TIMESTAMP when none is available.
    - Zone whose lowercased name is in `TIMEZONES`: a TIMESTAMP cast wrapped in
      `exp.AtTimeZone`.
    - Any other zone: a plain `TIMESTAMP(this, zone)` function call.
    """
    zone = expression.args.get("zone")

    if zone:
        if zone.name.lower() not in TIMEZONES:
            return self.func("TIMESTAMP", expression.this, zone)

        localized = exp.AtTimeZone(
            this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
            zone=zone,
        )
        return self.sql(localized)

    # NOTE(review): imported locally, presumably to avoid a circular import - confirm
    from sqlglot.optimizer.annotate_types import annotate_types

    target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
    return self.sql(exp.cast(expression.this, target_type))
def no_time_sql(self: Generator, expression: exp.Time) -> str:
    """Rewrites `exp.Time` as a TIMESTAMPTZ cast, shifted `AT TIME ZONE`, then cast to TIME."""
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    timestamp = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    localized = exp.AtTimeZone(this=timestamp, zone=expression.args.get("zone"))
    as_time = exp.cast(localized, exp.DataType.Type.TIME)
    return self.sql(as_time)
def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    """Renders `exp.Datetime` using casts instead of a `DATETIME` function.

    - No second argument: `this` is cast to TIMESTAMP.
    - Second argument whose lowercased name is in `TIMEZONES`: `this` is cast to TIMESTAMPTZ,
      shifted `AT TIME ZONE` and cast to TIMESTAMP.
    - Otherwise: `this` is cast to DATE, the second argument to TIME, and their sum is cast
      to TIMESTAMP.
    """
    this = expression.this
    expr = expression.expression

    if expr is None:
        # Without this guard, `expr.name` below raises AttributeError when there is nothing
        # to combine `this` with; a plain cast matches the type the other branches produce
        return self.sql(exp.cast(this, exp.DataType.Type.TIMESTAMP))

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
def timestrtotime_sql(
    self: Generator,
    expression: exp.TimeStrToTime,
    include_precision: bool = False,
) -> str:
    """Renders `exp.TimeStrToTime` as a cast to TIMESTAMP, or TIMESTAMPTZ if a zone is set.

    Args:
        expression: The node to render.
        include_precision: If set and `this` is a literal, the value returned by
            `subsecond_precision` for the literal is added as a type parameter when positive.
    """
    if expression.args.get("zone"):
        base_type = exp.DataType.Type.TIMESTAMPTZ
    else:
        base_type = exp.DataType.Type.TIMESTAMP

    datatype = exp.DataType.build(base_type)
    value = expression.this

    if include_precision and isinstance(value, exp.Literal):
        precision = subsecond_precision(value.name)
        if precision > 0:
            precision_param = exp.DataTypeParam(this=exp.Literal.number(precision))
            datatype = exp.DataType.build(datatype.this, expressions=[precision_param])

    return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Renders an encode/decode expression as a call to `name`.

    Args:
        expression: The node to render.
        name: The function name to emit.
        replace: Whether the expression's "replace" arg is forwarded to the function.

    A charset whose lowercased name is not "utf-8" is reported as unsupported.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    if replace:
        return self.func(name, expression.this, expression.args.get("replace"))

    return self.func(name, expression.this, None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Renders `exp.CountIf` as `sum(if(<condition>, 1, 0))`.

    A DISTINCT argument is reported as unsupported; its first expression is then used as
    the condition.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition, 1, 0)
    return self.func("sum", one_or_zero)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Renders `exp.Trim` as `TRIM([position ]chars FROM target[ COLLATE collation])`.

    When the expression has no characters to remove, rendering is delegated to the
    generator's own `trim_sql` method.
    """
    remove_chars = self.sql(expression, "expression")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars:
        return self.trim_sql(expression)

    # Only rendered past the early return, since the delegated call doesn't use them
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    collation = self.sql(expression, "collation")

    # `remove_chars` is known to be non-empty here, so the FROM keyword is always emitted
    position_part = f"{trim_type} " if trim_type else ""
    collation_part = f" COLLATE {collation}" if collation else ""
    return f"TRIM({position_part}{remove_chars} FROM {target}{collation_part})"
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Renders `exp.RegexpExtract` as `REGEXP_EXTRACT(this, pattern[, group])`.

    The `position`, `occurrence` and `parameters` args are reported as unsupported when set.
    """
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Renders `exp.RegexpReplace` as `REGEXP_REPLACE(this, pattern, replacement)`.

    The `position`, `occurrence` and `modifiers` args are reported as unsupported when set.
    """
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    # Read with `.get` so that an expression without a replacement doesn't raise a KeyError;
    # the missing value is handed to `self.func` as None, like other optional args
    replacement = expression.args.get("replacement")
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Computes one name per pivot aggregation.

    Args:
        aggregations: the aggregation expressions of the pivot.
        dialect: the dialect used to render aggregations that have no alias.

    Returns:
        The alias of each aliased aggregation. Unaliased aggregations are rendered as SQL
        with unquoted identifiers and lowercased function names.
    """

    def _unquote(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quotes).
        unquoted_agg = agg.transform(_unquote)
        names.append(unquoted_agg.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Builds a generator that renders ArgMax/ArgMin as a two-argument call to `name`.

    A "count" arg on the expression is reported as unsupported and left out of the call.
    """

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        count = expression.args.get("count")
        if count:
            self.unsupported(f"Only two arguments are supported in function {name}.")

        value, ordering = expression.this, expression.expression
        return self.func(name, value, ordering)

    return _arg_max_or_min_sql
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wraps the operand of a `TsOrDsAdd` in a cast to the expression's return type.

    The expression is modified in place and returned.
    """
    operand = expression.this.copy()
    return_type = expression.return_type

    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    casted = exp.cast(operand, return_type)
    expression.this.replace(casted)
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Builds a generator that renders date deltas as `name(unit, expression, this)`.

    Args:
        name: The function name to emit.
        cast: Whether `exp.TsOrDsAdd` inputs first go through `ts_or_ds_add_cast`.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        needs_cast = cast and isinstance(expression, exp.TsOrDsAdd)
        node = ts_or_ds_add_cast(expression) if needs_cast else expression
        return self.func(name, unit_to_var(node), node.expression, node.this)

    return _delta_sql
def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    """Converts an expression's "unit" arg into a string literal.

    Args:
        expression: The expression whose unit is read.
        default: The unit name used when the expression has no unit.

    Returns:
        The unit itself if it's a placeholder, a string literal holding its name if it's set,
        a string literal holding `default` if it isn't, or `None` when `default` is falsy too.
    """
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit

    if not unit:
        return exp.Literal.string(default) if default else None

    return exp.Literal.string(unit.name)
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    """Renders `exp.LastDay` without a dedicated function.

    The value is truncated to the month, one month is added, one day is subtracted, and
    the result is cast to DATE.
    """
    month_start = exp.func("date_trunc", "month", expression.this)
    next_month_start = exp.func("date_add", month_start, 1, "month")
    month_end = exp.func("date_sub", next_month_start, 1, "day")

    as_date = exp.cast(month_end, exp.DataType.Type.DATE)
    return self.sql(as_date)
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements.

    Columns qualified with the merge target's name, or with its alias, are replaced in place
    by unqualified columns inside each THEN clause before the statement is rendered with
    `self.merge_sql`.
    """

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        if not identifier:
            return None
        return self.dialect.normalize_identifier(identifier).name

    def strip_target(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets:
            return exp.column(node.this)
        return node

    target = expression.this
    alias = target.args.get("alias")

    targets = {normalize(target.this)}
    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        # only remove the target names from the THEN clause
        # they're still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED
        # ref: https://github.com/TobikoData/sqlmesh/issues/2934
        then = when.args.get("then")
        if not then:
            continue

        then.transform(strip_target, copy=False)

    return self.merge_sql(expression)
Remove table refs from columns in when statements.
def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    """Builds a parser callback that folds variadic path arguments into one `exp.JSONPath`.

    Args:
        expr_type: the expression class to instantiate.
        zero_based_indexing: if `False`, integer segments are decremented by one.
        arrow_req_json_type: value stored in the `only_json_types` arg of the result.

    Returns:
        A callable that accepts the argument list. Note that it truncates that list in place
        to its first two items when a path is built.
    """
    index_shift = 0 if zero_based_indexing else 1

    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]

        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                segment = exp.JSONPathSubscript(this=int(text) - index_shift)
            else:
                segment = exp.JSONPathKey(this=text)
            segments.append(segment)

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]

        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    """Builds a generator that renders each JSON path segment as a separate argument.

    Args:
        name: the function name to emit.
        quoted_index: whether subscript segments are wrapped in the dialect's quotes too.
        op: if set, the operand and the segments are joined with this operator instead of
            being passed to a function call.
    """

    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        json_path = expression.expression
        if not isinstance(json_path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for part in json_path.expressions:
            rendered = self.sql(part)
            if not rendered:
                continue

            skip_quotes = isinstance(part, exp.JSONPathSubscript) and not quoted_index
            if isinstance(part, exp.JSONPathPart) and not skip_quotes:
                rendered = f"{self.dialect.QUOTE_START}{rendered}{self.dialect.QUOTE_END}"

            segments.append(rendered)

        if not op:
            return self.func(name, expression.this, *segments)

        return f" {op} ".join([self.sql(expression.this), *segments])

    return _json_extract_segments
def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    """Renders `exp.ArrayFilter` as an array built from a filtered UNNEST subquery.

    A single-parameter lambda supplies both the element alias and the filter; a bare
    predicate uses the alias `_u`. Any other condition is reported as unsupported and an
    empty string is returned.
    """
    predicate = expression.expression

    if isinstance(predicate, exp.Lambda) and len(predicate.expressions) == 1:
        element, predicate = predicate.expressions[0], predicate.this
    elif isinstance(predicate, exp.Predicate):
        element = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    query = exp.select(element)
    source = exp.alias_(unnest, None, table=[element])
    filtered = query.from_(source).where(predicate)

    return self.sql(exp.Array(expressions=[filtered]))
def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    """Builds a callback that applies default DECIMAL parameters to a data type.

    Args:
        precision: the default precision. If `None`, types are returned unchanged.
        scale: the default scale, only included when it is not `None`.

    Returns:
        A callable that returns the given type as is when it already has expressions (or no
        default precision exists), and a newly built `DECIMAL(...)` type otherwise.
    """

    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        if scale is None:
            params = f"{precision}"
        else:
            params = f"{precision}, {scale}"

        return exp.DataType.build(f"DECIMAL({params})")

    return _builder
def build_timestamp_from_parts(args: t.List) -> exp.Func:
    """Parses `TIMESTAMP_FROM_PARTS` arguments.

    Returns an `exp.Anonymous` call for the two-argument form, and an
    `exp.TimestampFromParts` built from the argument list otherwise.
    """
    if len(args) != 2:
        return exp.TimestampFromParts.from_arg_list(args)

    # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
    # so we parse this into Anonymous for now instead of introducing complexity
    return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
    """Renders a series expression as `SEQUENCE(start, end[, step])`.

    If only one of the bounds is a cast to a date or timestamp type, the other bound is
    cast to that same type, so that both ends of the sequence agree.
    """
    start = expression.args.get("start")
    end = expression.args.get("end")
    step = expression.args.get("step")

    start_is_cast = isinstance(start, exp.Cast)
    if start_is_cast:
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    temporal = start and end and target_type and target_type.is_type("date", "timestamp")
    if temporal:
        if start_is_cast and target_type is start.to:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)