sqlglot.dialects.dialect
1from __future__ import annotations 2 3import logging 4import typing as t 5from enum import Enum, auto 6from functools import reduce 7 8from sqlglot import exp 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator 11from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses 12from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path 13from sqlglot.parser import Parser 14from sqlglot.time import TIMEZONES, format_time 15from sqlglot.tokens import Token, Tokenizer, TokenType 16from sqlglot.trie import new_trie 17 18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] 21 22 23if t.TYPE_CHECKING: 24 from sqlglot._typing import B, E, F 25 26 from sqlglot.optimizer.annotate_types import TypeAnnotator 27 28 AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]] 29 30logger = logging.getLogger("sqlglot") 31 32UNESCAPED_SEQUENCES = { 33 "\\a": "\a", 34 "\\b": "\b", 35 "\\f": "\f", 36 "\\n": "\n", 37 "\\r": "\r", 38 "\\t": "\t", 39 "\\v": "\v", 40 "\\\\": "\\", 41} 42 43 44def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: 45 return lambda self, e: self._annotate_with_type(e, data_type) 46 47 48class Dialects(str, Enum): 49 """Dialects supported by SQLGLot.""" 50 51 DIALECT = "" 52 53 ATHENA = "athena" 54 BIGQUERY = "bigquery" 55 CLICKHOUSE = "clickhouse" 56 DATABRICKS = "databricks" 57 DORIS = "doris" 58 DRILL = "drill" 59 DUCKDB = "duckdb" 60 HIVE = "hive" 61 MATERIALIZE = "materialize" 62 MYSQL = "mysql" 63 ORACLE = "oracle" 64 POSTGRES = "postgres" 65 PRESTO = "presto" 66 PRQL = "prql" 67 REDSHIFT = "redshift" 68 RISINGWAVE = "risingwave" 69 SNOWFLAKE = "snowflake" 70 SPARK = "spark" 71 SPARK2 = "spark2" 72 SQLITE = "sqlite" 73 STARROCKS = "starrocks" 74 TABLEAU = "tableau" 75 TERADATA = "teradata" 76 TRINO = "trino" 77 TSQL = "tsql" 78 79 80class NormalizationStrategy(str, AutoName): 81 """Specifies the strategy according to which identifiers should be normalized.""" 82 83 LOWERCASE = auto() 84 """Unquoted identifiers are lowercased.""" 85 86 UPPERCASE = auto() 87 """Unquoted identifiers are uppercased.""" 88 89 CASE_SENSITIVE = auto() 90 """Always case-sensitive, regardless of quotes.""" 91 92 CASE_INSENSITIVE = auto() 93 """Always case-insensitive, regardless of quotes.""" 94 95 96class _Dialect(type): 97 classes: t.Dict[str, t.Type[Dialect]] = {} 98 99 def __eq__(cls, other: t.Any) -> bool: 100 if cls is other: 101 return True 102 if isinstance(other, str): 103 return cls is cls.get(other) 104 if isinstance(other, Dialect): 105 return cls is type(other) 106 107 return False 108 109 def __hash__(cls) -> int: 110 return hash(cls.__name__.lower()) 111 112 @classmethod 113 def __getitem__(cls, key: str) -> t.Type[Dialect]: 114 return cls.classes[key] 115 116 @classmethod 117 def get( 118 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 119 ) -> t.Optional[t.Type[Dialect]]: 120 return cls.classes.get(key, default) 121 122 def __new__(cls, clsname, bases, attrs): 123 klass = super().__new__(cls, clsname, bases, attrs) 124 enum = Dialects.__members__.get(clsname.upper()) 125 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 126 127 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 128 klass.FORMAT_TRIE = ( 129 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 130 ) 131 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 132 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 133 klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()} 134 klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING) 135 136 base = seq_get(bases, 0) 137 base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),) 138 base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),) 139 base_parser = (getattr(base, "parser_class", Parser),) 140 base_generator = (getattr(base, "generator_class", Generator),) 141 142 klass.tokenizer_class = klass.__dict__.get( 143 "Tokenizer", type("Tokenizer", base_tokenizer, {}) 144 ) 145 klass.jsonpath_tokenizer_class = klass.__dict__.get( 146 "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {}) 147 ) 148 klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {})) 149 klass.generator_class = klass.__dict__.get( 150 "Generator", type("Generator", base_generator, {}) 151 ) 152 153 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 154 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 155 klass.tokenizer_class._IDENTIFIERS.items() 156 )[0] 157 158 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 159 return next( 160 ( 161 (s, e) 162 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 163 if t == token_type 164 ), 165 (None, None), 166 ) 167 168 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 169 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 170 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 171 klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 172 173 if "\\" in klass.tokenizer_class.STRING_ESCAPES: 174 klass.UNESCAPED_SEQUENCES = { 175 **UNESCAPED_SEQUENCES, 176 **klass.UNESCAPED_SEQUENCES, 177 } 178 179 klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()} 180 181 klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS 182 183 if enum not in ("", "bigquery"): 184 klass.generator_class.SELECT_KINDS = () 185 186 if enum not in ("", "clickhouse"): 187 klass.generator_class.SUPPORTS_NULLABLE_TYPES = False 188 189 if enum not in ("", "athena", "presto", "trino"): 190 klass.generator_class.TRY_SUPPORTED = False 191 klass.generator_class.SUPPORTS_UESCAPE = False 192 193 if enum not in ("", "databricks", "hive", "spark", "spark2"): 194 modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() 195 for modifier in ("cluster", "distribute", "sort"): 196 modifier_transforms.pop(modifier, None) 197 198 klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms 199 200 if enum not in ("", "doris", "mysql"): 201 klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { 202 TokenType.STRAIGHT_JOIN, 203 } 204 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 205 TokenType.STRAIGHT_JOIN, 206 } 207 208 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 209 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 210 TokenType.ANTI, 211 TokenType.SEMI, 212 } 213 214 return klass 215 216 217class Dialect(metaclass=_Dialect): 218 INDEX_OFFSET = 0 219 """The base index offset for arrays.""" 220 221 WEEK_OFFSET = 0 222 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 223 224 UNNEST_COLUMN_ONLY = False 225 """Whether `UNNEST` table aliases are treated as column aliases.""" 226 227 ALIAS_POST_TABLESAMPLE = False 228 """Whether the table alias comes after tablesample.""" 229 230 TABLESAMPLE_SIZE_IS_PERCENT = False 231 """Whether a size in the table sample clause represents percentage.""" 232 233 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 234 """Specifies the strategy according to which identifiers should be normalized.""" 235 236 IDENTIFIERS_CAN_START_WITH_DIGIT = False 237 """Whether an unquoted identifier can start with a digit.""" 238 239 DPIPE_IS_STRING_CONCAT = True 240 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 241 242 STRICT_STRING_CONCAT = False 243 """Whether `CONCAT`'s arguments must be strings.""" 244 245 SUPPORTS_USER_DEFINED_TYPES = True 246 """Whether user-defined data types are supported.""" 247 248 SUPPORTS_SEMI_ANTI_JOIN = True 249 """Whether `SEMI` or `ANTI` joins are supported.""" 250 251 SUPPORTS_COLUMN_JOIN_MARKS = False 252 """Whether the old-style outer join (+) syntax is supported.""" 253 254 COPY_PARAMS_ARE_CSV = True 255 """Separator of COPY statement parameters.""" 256 257 NORMALIZE_FUNCTIONS: bool | str = "upper" 258 """ 259 Determines how function names are going to be normalized. 260 Possible values: 261 "upper" or True: Convert names to uppercase. 262 "lower": Convert names to lowercase. 263 False: Disables function name normalization. 264 """ 265 266 LOG_BASE_FIRST: t.Optional[bool] = True 267 """ 268 Whether the base comes first in the `LOG` function. 269 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 270 """ 271 272 NULL_ORDERING = "nulls_are_small" 273 """ 274 Default `NULL` ordering method to use if not explicitly set. 275 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 276 """ 277 278 TYPED_DIVISION = False 279 """ 280 Whether the behavior of `a / b` depends on the types of `a` and `b`. 281 False means `a / b` is always float division. 282 True means `a / b` is integer division if both `a` and `b` are integers. 283 """ 284 285 SAFE_DIVISION = False 286 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 287 288 CONCAT_COALESCE = False 289 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 290 291 HEX_LOWERCASE = False 292 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 293 294 DATE_FORMAT = "'%Y-%m-%d'" 295 DATEINT_FORMAT = "'%Y%m%d'" 296 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 297 298 TIME_MAPPING: t.Dict[str, str] = {} 299 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 300 301 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 302 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 303 FORMAT_MAPPING: t.Dict[str, str] = {} 304 """ 305 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 306 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 307 """ 308 309 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 310 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 311 312 PSEUDOCOLUMNS: t.Set[str] = set() 313 """ 314 Columns that are auto-generated by the engine corresponding to this dialect. 315 For example, such columns may be excluded from `SELECT *` queries. 316 """ 317 318 PREFER_CTE_ALIAS_COLUMN = False 319 """ 320 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 321 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 322 any projection aliases in the subquery. 323 324 For example, 325 WITH y(c) AS ( 326 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 327 ) SELECT c FROM y; 328 329 will be rewritten as 330 331 WITH y(c) AS ( 332 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 333 ) SELECT c FROM y; 334 """ 335 336 COPY_PARAMS_ARE_CSV = True 337 """ 338 Whether COPY statement parameters are separated by comma or whitespace 339 """ 340 341 FORCE_EARLY_ALIAS_REF_EXPANSION = False 342 """ 343 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 344 345 For example: 346 WITH data AS ( 347 SELECT 348 1 AS id, 349 2 AS my_id 350 ) 351 SELECT 352 id AS my_id 353 FROM 354 data 355 WHERE 356 my_id = 1 357 GROUP BY 358 my_id, 359 HAVING 360 my_id = 1 361 362 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 363 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 364 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 365 """ 366 367 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 368 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 369 370 SUPPORTS_ORDER_BY_ALL = False 371 """ 372 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 373 """ 374 375 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 376 """ 377 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 378 as the former is of type INT[] vs the latter which is SUPER 379 """ 380 381 SUPPORTS_FIXED_SIZE_ARRAYS = False 382 """ 383 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In 384 dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator 385 """ 386 387 # --- Autofilled --- 388 389 tokenizer_class = Tokenizer 390 jsonpath_tokenizer_class = JSONPathTokenizer 391 parser_class = Parser 392 generator_class = Generator 393 394 # A trie of the time_mapping keys 395 TIME_TRIE: t.Dict = {} 396 FORMAT_TRIE: t.Dict = {} 397 398 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 399 INVERSE_TIME_TRIE: t.Dict = {} 400 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 401 INVERSE_FORMAT_TRIE: t.Dict = {} 402 403 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 404 405 # Delimiters for string literals and identifiers 406 QUOTE_START = "'" 407 QUOTE_END = "'" 408 IDENTIFIER_START = '"' 409 IDENTIFIER_END = '"' 410 411 # Delimiters for bit, hex, byte and unicode literals 412 BIT_START: t.Optional[str] = None 413 BIT_END: t.Optional[str] = None 414 HEX_START: t.Optional[str] = None 415 HEX_END: t.Optional[str] = None 416 BYTE_START: t.Optional[str] = None 417 BYTE_END: t.Optional[str] = None 418 UNICODE_START: t.Optional[str] = None 419 UNICODE_END: t.Optional[str] = None 420 421 DATE_PART_MAPPING = { 422 "Y": "YEAR", 423 "YY": "YEAR", 424 "YYY": "YEAR", 425 "YYYY": "YEAR", 426 "YR": "YEAR", 427 "YEARS": "YEAR", 428 "YRS": "YEAR", 429 "MM": "MONTH", 430 "MON": "MONTH", 431 "MONS": "MONTH", 432 "MONTHS": "MONTH", 433 "D": "DAY", 434 "DD": "DAY", 435 "DAYS": "DAY", 436 "DAYOFMONTH": "DAY", 437 "DAY OF WEEK": "DAYOFWEEK", 438 "WEEKDAY": "DAYOFWEEK", 439 "DOW": "DAYOFWEEK", 440 "DW": "DAYOFWEEK", 441 "WEEKDAY_ISO": "DAYOFWEEKISO", 442 "DOW_ISO": "DAYOFWEEKISO", 443 "DW_ISO": "DAYOFWEEKISO", 444 "DAY OF YEAR": "DAYOFYEAR", 445 "DOY": "DAYOFYEAR", 446 "DY": "DAYOFYEAR", 447 "W": "WEEK", 448 "WK": "WEEK", 449 "WEEKOFYEAR": "WEEK", 450 "WOY": "WEEK", 451 "WY": "WEEK", 452 "WEEK_ISO": "WEEKISO", 453 "WEEKOFYEARISO": "WEEKISO", 454 "WEEKOFYEAR_ISO": "WEEKISO", 455 "Q": "QUARTER", 456 "QTR": "QUARTER", 457 "QTRS": "QUARTER", 458 "QUARTERS": "QUARTER", 459 "H": "HOUR", 460 "HH": "HOUR", 461 "HR": "HOUR", 462 "HOURS": "HOUR", 463 "HRS": "HOUR", 464 "M": "MINUTE", 465 "MI": "MINUTE", 466 "MIN": "MINUTE", 467 "MINUTES": "MINUTE", 468 "MINS": "MINUTE", 469 "S": "SECOND", 470 "SEC": "SECOND", 471 "SECONDS": "SECOND", 472 "SECS": "SECOND", 473 "MS": "MILLISECOND", 474 "MSEC": "MILLISECOND", 475 "MSECS": "MILLISECOND", 476 "MSECOND": "MILLISECOND", 477 "MSECONDS": "MILLISECOND", 478 "MILLISEC": "MILLISECOND", 479 "MILLISECS": "MILLISECOND", 480 "MILLISECON": "MILLISECOND", 481 "MILLISECONDS": "MILLISECOND", 482 "US": "MICROSECOND", 483 "USEC": "MICROSECOND", 484 "USECS": "MICROSECOND", 485 "MICROSEC": "MICROSECOND", 486 "MICROSECS": "MICROSECOND", 487 "USECOND": "MICROSECOND", 488 "USECONDS": "MICROSECOND", 489 "MICROSECONDS": "MICROSECOND", 490 "NS": "NANOSECOND", 491 "NSEC": "NANOSECOND", 492 "NANOSEC": "NANOSECOND", 493 "NSECOND": "NANOSECOND", 494 "NSECONDS": "NANOSECOND", 495 "NANOSECS": "NANOSECOND", 496 "EPOCH_SECOND": "EPOCH", 497 "EPOCH_SECONDS": "EPOCH", 498 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 499 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 500 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 501 "TZH": "TIMEZONE_HOUR", 502 "TZM": "TIMEZONE_MINUTE", 503 "DEC": "DECADE", 504 "DECS": "DECADE", 505 "DECADES": "DECADE", 506 "MIL": "MILLENIUM", 507 "MILS": "MILLENIUM", 508 "MILLENIA": "MILLENIUM", 509 "C": "CENTURY", 510 "CENT": "CENTURY", 511 "CENTS": "CENTURY", 512 "CENTURIES": "CENTURY", 513 } 514 515 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 516 exp.DataType.Type.BIGINT: { 517 exp.ApproxDistinct, 518 exp.ArraySize, 519 exp.Count, 520 exp.Length, 521 }, 522 exp.DataType.Type.BOOLEAN: { 523 exp.Between, 524 exp.Boolean, 525 exp.In, 526 exp.RegexpLike, 527 }, 528 exp.DataType.Type.DATE: { 529 exp.CurrentDate, 530 exp.Date, 531 exp.DateFromParts, 532 exp.DateStrToDate, 533 exp.DiToDate, 534 exp.StrToDate, 535 exp.TimeStrToDate, 536 exp.TsOrDsToDate, 537 }, 538 exp.DataType.Type.DATETIME: { 539 exp.CurrentDatetime, 540 exp.Datetime, 541 exp.DatetimeAdd, 542 exp.DatetimeSub, 543 }, 544 exp.DataType.Type.DOUBLE: { 545 exp.ApproxQuantile, 546 exp.Avg, 547 exp.Div, 548 exp.Exp, 549 exp.Ln, 550 exp.Log, 551 exp.Pow, 552 exp.Quantile, 553 exp.Round, 554 exp.SafeDivide, 555 exp.Sqrt, 556 exp.Stddev, 557 exp.StddevPop, 558 exp.StddevSamp, 559 exp.Variance, 560 exp.VariancePop, 561 }, 562 exp.DataType.Type.INT: { 563 exp.Ceil, 564 exp.DatetimeDiff, 565 exp.DateDiff, 566 exp.TimestampDiff, 567 exp.TimeDiff, 568 exp.DateToDi, 569 exp.Levenshtein, 570 exp.Sign, 571 exp.StrPosition, 572 exp.TsOrDiToDi, 573 }, 574 exp.DataType.Type.JSON: { 575 exp.ParseJSON, 576 }, 577 exp.DataType.Type.TIME: { 578 exp.Time, 579 }, 580 exp.DataType.Type.TIMESTAMP: { 581 exp.CurrentTime, 582 exp.CurrentTimestamp, 583 exp.StrToTime, 584 exp.TimeAdd, 585 exp.TimeStrToTime, 586 exp.TimeSub, 587 exp.TimestampAdd, 588 exp.TimestampSub, 589 exp.UnixToTime, 590 }, 591 exp.DataType.Type.TINYINT: { 592 exp.Day, 593 exp.Month, 594 exp.Week, 595 exp.Year, 596 exp.Quarter, 597 }, 598 exp.DataType.Type.VARCHAR: { 599 exp.ArrayConcat, 600 exp.Concat, 601 exp.ConcatWs, 602 exp.DateToDateStr, 603 exp.GroupConcat, 604 exp.Initcap, 605 exp.Lower, 606 exp.Substring, 607 exp.TimeToStr, 608 exp.TimeToTimeStr, 609 exp.Trim, 610 exp.TsOrDsToDateStr, 611 exp.UnixToStr, 612 exp.UnixToTimeStr, 613 exp.Upper, 614 }, 615 } 616 617 ANNOTATORS: AnnotatorsType = { 618 **{ 619 expr_type: lambda self, e: self._annotate_unary(e) 620 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 621 }, 622 **{ 623 expr_type: lambda self, e: self._annotate_binary(e) 624 for expr_type in subclasses(exp.__name__, exp.Binary) 625 }, 626 **{ 627 expr_type: _annotate_with_type_lambda(data_type) 628 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 629 for expr_type in expressions 630 }, 631 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 632 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 633 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 634 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 635 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 636 exp.Bracket: lambda self, e: self._annotate_bracket(e), 637 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 638 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 639 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 640 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 641 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 642 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 643 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 644 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 645 exp.Div: lambda self, e: self._annotate_div(e), 646 exp.Dot: lambda self, e: self._annotate_dot(e), 647 exp.Explode: lambda self, e: self._annotate_explode(e), 648 exp.Extract: lambda self, e: self._annotate_extract(e), 649 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 650 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 651 e, exp.DataType.build("ARRAY<DATE>") 652 ), 653 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 654 e, exp.DataType.build("ARRAY<TIMESTAMP>") 655 ), 656 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 657 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 658 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 659 exp.Literal: lambda self, e: self._annotate_literal(e), 660 exp.Map: lambda self, e: self._annotate_map(e), 661 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 662 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 663 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 664 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 665 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 666 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 667 exp.Struct: lambda self, e: self._annotate_struct(e), 668 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 669 exp.Timestamp: lambda self, e: self._annotate_with_type( 670 e, 671 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 672 ), 673 exp.ToMap: lambda self, e: self._annotate_to_map(e), 674 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 675 exp.Unnest: lambda self, e: self._annotate_unnest(e), 676 exp.VarMap: lambda self, e: self._annotate_map(e), 677 } 678 679 @classmethod 680 def get_or_raise(cls, dialect: DialectType) -> Dialect: 681 """ 682 Look up a dialect in the global dialect registry and return it if it exists. 683 684 Args: 685 dialect: The target dialect. If this is a string, it can be optionally followed by 686 additional key-value pairs that are separated by commas and are used to specify 687 dialect settings, such as whether the dialect's identifiers are case-sensitive. 688 689 Example: 690 >>> dialect = dialect_class = get_or_raise("duckdb") 691 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 692 693 Returns: 694 The corresponding Dialect instance. 695 """ 696 697 if not dialect: 698 return cls() 699 if isinstance(dialect, _Dialect): 700 return dialect() 701 if isinstance(dialect, Dialect): 702 return dialect 703 if isinstance(dialect, str): 704 try: 705 dialect_name, *kv_strings = dialect.split(",") 706 kv_pairs = (kv.split("=") for kv in kv_strings) 707 kwargs = {} 708 for pair in kv_pairs: 709 key = pair[0].strip() 710 value: t.Union[bool | str | None] = None 711 712 if len(pair) == 1: 713 # Default initialize standalone settings to True 714 value = True 715 elif len(pair) == 2: 716 value = pair[1].strip() 717 718 # Coerce the value to boolean if it matches to the truthy/falsy values below 719 value_lower = value.lower() 720 if value_lower in ("true", "1"): 721 value = True 722 elif value_lower in ("false", "0"): 723 value = False 724 725 kwargs[key] = value 726 727 except ValueError: 728 raise ValueError( 729 f"Invalid dialect format: '{dialect}'. " 730 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 731 ) 732 733 result = cls.get(dialect_name.strip()) 734 if not result: 735 from difflib import get_close_matches 736 737 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 738 if similar: 739 similar = f" Did you mean {similar}?" 740 741 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 742 743 return result(**kwargs) 744 745 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 746 747 @classmethod 748 def format_time( 749 cls, expression: t.Optional[str | exp.Expression] 750 ) -> t.Optional[exp.Expression]: 751 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 752 if isinstance(expression, str): 753 return exp.Literal.string( 754 # the time formats are quoted 755 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 756 ) 757 758 if expression and expression.is_string: 759 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 760 761 return expression 762 763 def __init__(self, **kwargs) -> None: 764 normalization_strategy = kwargs.pop("normalization_strategy", None) 765 766 if normalization_strategy is None: 767 self.normalization_strategy = self.NORMALIZATION_STRATEGY 768 else: 769 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 770 771 self.settings = kwargs 772 773 def __eq__(self, other: t.Any) -> bool: 774 # Does not currently take dialect state into account 775 return type(self) == other 776 777 def __hash__(self) -> int: 778 # Does not currently take dialect state into account 779 return hash(type(self)) 780 781 def normalize_identifier(self, expression: E) -> E: 782 """ 783 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 784 785 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 786 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 787 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 788 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 789 790 There are also dialects like Spark, which are case-insensitive even when quotes are 791 present, and dialects like MySQL, whose resolution rules match those employed by the 792 underlying operating system, for example they may always be case-sensitive in Linux. 793 794 Finally, the normalization behavior of some engines can even be controlled through flags, 795 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 796 797 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 798 that it can analyze queries in the optimizer and successfully capture their semantics. 799 """ 800 if ( 801 isinstance(expression, exp.Identifier) 802 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 803 and ( 804 not expression.quoted 805 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 806 ) 807 ): 808 expression.set( 809 "this", 810 ( 811 expression.this.upper() 812 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 813 else expression.this.lower() 814 ), 815 ) 816 817 return expression 818 819 def case_sensitive(self, text: str) -> bool: 820 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 821 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 822 return False 823 824 unsafe = ( 825 str.islower 826 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 827 else str.isupper 828 ) 829 return any(unsafe(char) for char in text) 830 831 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 832 """Checks if text can be identified given an identify option. 833 834 Args: 835 text: The text to check. 836 identify: 837 `"always"` or `True`: Always returns `True`. 838 `"safe"`: Only returns `True` if the identifier is case-insensitive. 839 840 Returns: 841 Whether the given text can be identified. 842 """ 843 if identify is True or identify == "always": 844 return True 845 846 if identify == "safe": 847 return not self.case_sensitive(text) 848 849 return False 850 851 def quote_identifier(self, expression: E, identify: bool = True) -> E: 852 """ 853 Adds quotes to a given identifier. 854 855 Args: 856 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 857 identify: If set to `False`, the quotes will only be added if the identifier is deemed 858 "unsafe", with respect to its characters and this dialect's normalization strategy. 859 """ 860 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 861 name = expression.this 862 expression.set( 863 "quoted", 864 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 865 ) 866 867 return expression 868 869 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 870 if isinstance(path, exp.Literal): 871 path_text = path.name 872 if path.is_number: 873 path_text = f"[{path_text}]" 874 try: 875 return parse_json_path(path_text, self) 876 except ParseError as e: 877 logger.warning(f"Invalid JSON path syntax. {str(e)}") 878 879 return path 880 881 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 882 return self.parser(**opts).parse(self.tokenize(sql), sql) 883 884 def parse_into( 885 self, expression_type: exp.IntoType, sql: str, **opts 886 ) -> t.List[t.Optional[exp.Expression]]: 887 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 888 889 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 890 return self.generator(**opts).generate(expression, copy=copy) 891 892 def transpile(self, sql: str, **opts) -> t.List[str]: 893 return [ 894 self.generate(expression, copy=False, **opts) if expression else "" 895 for expression in self.parse(sql) 896 ] 897 898 def tokenize(self, sql: str) -> t.List[Token]: 899 return self.tokenizer.tokenize(sql) 900 901 @property 902 def tokenizer(self) -> Tokenizer: 903 return self.tokenizer_class(dialect=self) 904 905 @property 906 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 907 return self.jsonpath_tokenizer_class(dialect=self) 908 909 def parser(self, **opts) -> Parser: 910 return self.parser_class(dialect=self, **opts) 911 912 def generator(self, **opts) -> Generator: 913 return self.generator_class(dialect=self, **opts) 914 915 916DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 917 918 919def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 920 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 921 922 923def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 924 if expression.args.get("accuracy"): 925 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 926 return self.func("APPROX_COUNT_DISTINCT", expression.this) 927 928 929def if_sql( 930 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 931) -> t.Callable[[Generator, exp.If], str]: 932 def _if_sql(self: Generator, expression: exp.If) -> str: 933 return self.func( 934 name, 935 expression.this, 936 expression.args.get("true"), 937 expression.args.get("false") or false_value, 938 ) 939 940 return _if_sql 941 942 943def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 944 this = expression.this 945 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 946 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 947 948 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 949 950 951def inline_array_sql(self: Generator, expression: exp.Array) -> str: 952 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 953 954 955def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 956 elem = seq_get(expression.expressions, 0) 957 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 958 return self.func("ARRAY", elem) 959 return inline_array_sql(self, expression) 960 961 962def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 963 return self.like_sql( 964 exp.Like( 965 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 966 ) 967 ) 968 969 970def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 971 zone = self.sql(expression, "this") 972 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 973 974 975def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 976 if expression.args.get("recursive"): 977 self.unsupported("Recursive CTEs are unsupported") 978 expression.args["recursive"] = False 979 return self.with_sql(expression) 980 981 982def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 983 n = self.sql(expression, "this") 984 d = self.sql(expression, "expression") 985 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 986 987 988def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 989 self.unsupported("TABLESAMPLE unsupported") 990 return self.sql(expression.this) 991 992 993def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 994 self.unsupported("PIVOT unsupported") 995 return "" 996 997 998def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 999 return self.cast_sql(expression) 1000 1001 1002def no_comment_column_constraint_sql( 1003 self: Generator, expression: exp.CommentColumnConstraint 1004) -> str: 1005 self.unsupported("CommentColumnConstraint unsupported") 1006 return "" 1007 1008 1009def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1010 self.unsupported("MAP_FROM_ENTRIES unsupported") 1011 return "" 1012 1013 1014def str_position_sql( 1015 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1016) -> str: 1017 this = self.sql(expression, "this") 1018 substr = self.sql(expression, "substr") 1019 position = self.sql(expression, "position") 1020 instance = expression.args.get("instance") if generate_instance else None 1021 position_offset = "" 1022 1023 if position: 1024 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1025 this = self.func("SUBSTR", this, position) 1026 position_offset = f" + {position} - 1" 1027 1028 return self.func("STRPOS", this, substr, instance) + position_offset 1029 1030 1031def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1032 return ( 1033 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1034 ) 1035 1036 1037def var_map_sql( 1038 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1039) -> str: 1040 keys = expression.args["keys"] 1041 values = expression.args["values"] 1042 1043 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1044 self.unsupported("Cannot convert array columns into map.") 1045 return self.func(map_func_name, keys, values) 1046 1047 args = [] 1048 for key, value in zip(keys.expressions, values.expressions): 1049 args.append(self.sql(key)) 1050 args.append(self.sql(value)) 1051 1052 return self.func(map_func_name, *args) 1053 1054 1055def build_formatted_time( 1056 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1057) -> t.Callable[[t.List], E]: 1058 """Helper used for time expressions. 1059 1060 Args: 1061 exp_class: the expression class to instantiate. 1062 dialect: target sql dialect. 1063 default: the default format, True being time. 1064 1065 Returns: 1066 A callable that can be used to return the appropriately formatted time expression. 1067 """ 1068 1069 def _builder(args: t.List): 1070 return exp_class( 1071 this=seq_get(args, 0), 1072 format=Dialect[dialect].format_time( 1073 seq_get(args, 1) 1074 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1075 ), 1076 ) 1077 1078 return _builder 1079 1080 1081def time_format( 1082 dialect: DialectType = None, 1083) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1084 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1085 """ 1086 Returns the time format for a given expression, unless it's equivalent 1087 to the default time format of the dialect of interest. 1088 """ 1089 time_format = self.format_time(expression) 1090 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1091 1092 return _time_format 1093 1094 1095def build_date_delta( 1096 exp_class: t.Type[E], 1097 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1098 default_unit: t.Optional[str] = "DAY", 1099) -> t.Callable[[t.List], E]: 1100 def _builder(args: t.List) -> E: 1101 unit_based = len(args) == 3 1102 this = args[2] if unit_based else seq_get(args, 0) 1103 unit = None 1104 if unit_based or default_unit: 1105 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1106 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1107 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1108 1109 return _builder 1110 1111 1112def build_date_delta_with_interval( 1113 expression_class: t.Type[E], 1114) -> t.Callable[[t.List], t.Optional[E]]: 1115 def _builder(args: t.List) -> t.Optional[E]: 1116 if len(args) < 2: 1117 return None 1118 1119 interval = args[1] 1120 1121 if not isinstance(interval, exp.Interval): 1122 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1123 1124 expression = interval.this 1125 if expression and expression.is_string: 1126 expression = exp.Literal.number(expression.this) 1127 1128 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 1129 1130 return _builder 1131 1132 1133def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1134 unit = seq_get(args, 0) 1135 this = seq_get(args, 1) 1136 1137 if isinstance(this, exp.Cast) and this.is_type("date"): 1138 return exp.DateTrunc(unit=unit, this=this) 1139 return exp.TimestampTrunc(this=this, unit=unit) 1140 1141 1142def date_add_interval_sql( 1143 data_type: str, kind: str 1144) -> t.Callable[[Generator, exp.Expression], str]: 1145 def func(self: Generator, expression: exp.Expression) -> str: 1146 this = self.sql(expression, "this") 1147 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1148 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1149 1150 return func 1151 1152 1153def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1154 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1155 args = [unit_to_str(expression), expression.this] 1156 if zone: 1157 args.append(expression.args.get("zone")) 1158 return self.func("DATE_TRUNC", *args) 1159 1160 return _timestamptrunc_sql 1161 1162 1163def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1164 zone = expression.args.get("zone") 1165 if not zone: 1166 from sqlglot.optimizer.annotate_types import annotate_types 1167 1168 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1169 return self.sql(exp.cast(expression.this, target_type)) 1170 if zone.name.lower() in TIMEZONES: 1171 return self.sql( 1172 exp.AtTimeZone( 1173 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1174 zone=zone, 1175 ) 1176 ) 1177 return self.func("TIMESTAMP", expression.this, zone) 1178 1179 1180def no_time_sql(self: Generator, expression: exp.Time) -> str: 1181 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1182 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1183 expr = exp.cast( 1184 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1185 ) 1186 return self.sql(expr) 1187 1188 1189def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1190 this = expression.this 1191 expr = expression.expression 1192 1193 if expr.name.lower() in TIMEZONES: 1194 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1195 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1196 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1197 return self.sql(this) 1198 1199 this = exp.cast(this, exp.DataType.Type.DATE) 1200 expr = exp.cast(expr, exp.DataType.Type.TIME) 1201 1202 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 1203 1204 1205def locate_to_strposition(args: t.List) -> exp.Expression: 1206 return exp.StrPosition( 1207 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 1208 ) 1209 1210 1211def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 1212 return self.func( 1213 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 1214 ) 1215 1216 1217def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1218 return self.sql( 1219 exp.Substring( 1220 this=expression.this, start=exp.Literal.number(1), length=expression.expression 1221 ) 1222 ) 1223 1224 1225def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1226 return self.sql( 1227 exp.Substring( 1228 this=expression.this, 1229 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 1230 ) 1231 ) 1232 1233 1234def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 1235 return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)) 1236 1237 1238def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 1239 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 1240 1241 1242# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1243def encode_decode_sql( 1244 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1245) -> str: 1246 charset = expression.args.get("charset") 1247 if charset and charset.name.lower() != "utf-8": 1248 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1249 1250 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1251 1252 1253def min_or_least(self: Generator, expression: exp.Min) -> str: 1254 name = "LEAST" if expression.expressions else "MIN" 1255 return rename_func(name)(self, expression) 1256 1257 1258def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1259 name = "GREATEST" if expression.expressions else "MAX" 1260 return rename_func(name)(self, expression) 1261 1262 1263def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1264 cond = expression.this 1265 1266 if isinstance(expression.this, exp.Distinct): 1267 cond = expression.this.expressions[0] 1268 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1269 1270 return self.func("sum", exp.func("if", cond, 1, 0)) 1271 1272 1273def trim_sql(self: Generator, expression: exp.Trim) -> str: 1274 target = self.sql(expression, "this") 1275 trim_type = self.sql(expression, "position") 1276 remove_chars = self.sql(expression, "expression") 1277 collation = self.sql(expression, "collation") 1278 1279 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1280 if not remove_chars and not collation: 1281 return self.trim_sql(expression) 1282 1283 trim_type = f"{trim_type} " if trim_type else "" 1284 remove_chars = f"{remove_chars} " if remove_chars else "" 1285 from_part = "FROM " if trim_type or remove_chars else "" 1286 collation = f" COLLATE {collation}" if collation else "" 1287 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1288 1289 1290def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1291 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1292 1293 1294def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1295 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1296 1297 1298def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1299 delim, *rest_args = expression.expressions 1300 return self.sql( 1301 reduce( 1302 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1303 rest_args, 1304 ) 1305 ) 1306 1307 1308def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1309 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1310 if bad_args: 1311 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1312 1313 return self.func( 1314 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1315 ) 1316 1317 1318def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1319 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1320 if bad_args: 1321 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1322 1323 return self.func( 1324 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1325 ) 1326 1327 1328def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1329 names = [] 1330 for agg in aggregations: 1331 if isinstance(agg, exp.Alias): 1332 names.append(agg.alias) 1333 else: 1334 """ 1335 This case corresponds to aggregations without aliases being used as suffixes 1336 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1337 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1338 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1339 """ 1340 agg_all_unquoted = agg.transform( 1341 lambda node: ( 1342 exp.Identifier(this=node.name, quoted=False) 1343 if isinstance(node, exp.Identifier) 1344 else node 1345 ) 1346 ) 1347 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1348 1349 return names 1350 1351 1352def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1353 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1354 1355 1356# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1357def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1358 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1359 1360 1361def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1362 return self.func("MAX", expression.this) 1363 1364 1365def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1366 a = self.sql(expression.left) 1367 b = self.sql(expression.right) 1368 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1369 1370 1371def is_parse_json(expression: exp.Expression) -> bool: 1372 return isinstance(expression, exp.ParseJSON) or ( 1373 isinstance(expression, exp.Cast) and expression.is_type("json") 1374 ) 1375 1376 1377def isnull_to_is_null(args: t.List) -> exp.Expression: 1378 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1379 1380 1381def generatedasidentitycolumnconstraint_sql( 1382 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1383) -> str: 1384 start = self.sql(expression, "start") or "1" 1385 increment = self.sql(expression, "increment") or "1" 1386 return f"IDENTITY({start}, {increment})" 1387 1388 1389def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1390 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1391 if expression.args.get("count"): 1392 self.unsupported(f"Only two arguments are supported in function {name}.") 1393 1394 return self.func(name, expression.this, expression.expression) 1395 1396 return _arg_max_or_min_sql 1397 1398 1399def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1400 this = expression.this.copy() 1401 1402 return_type = expression.return_type 1403 if return_type.is_type(exp.DataType.Type.DATE): 1404 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1405 # can truncate timestamp strings, because some dialects can't cast them to DATE 1406 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1407 1408 expression.this.replace(exp.cast(this, return_type)) 1409 return expression 1410 1411 1412def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1413 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1414 if cast and isinstance(expression, exp.TsOrDsAdd): 1415 expression = ts_or_ds_add_cast(expression) 1416 1417 return self.func( 1418 name, 1419 unit_to_var(expression), 1420 expression.expression, 1421 expression.this, 1422 ) 1423 1424 return _delta_sql 1425 1426 1427def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1428 unit = expression.args.get("unit") 1429 1430 if isinstance(unit, exp.Placeholder): 1431 return unit 1432 if unit: 1433 return exp.Literal.string(unit.name) 1434 return exp.Literal.string(default) if default else None 1435 1436 1437def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1438 unit = expression.args.get("unit") 1439 1440 if isinstance(unit, (exp.Var, exp.Placeholder)): 1441 return unit 1442 return exp.Var(this=default) if default else None 1443 1444 1445@t.overload 1446def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1447 pass 1448 1449 1450@t.overload 1451def map_date_part( 1452 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1453) -> t.Optional[exp.Expression]: 1454 pass 1455 1456 1457def map_date_part(part, dialect: DialectType = Dialect): 1458 mapped = ( 1459 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1460 ) 1461 return exp.var(mapped) if mapped else part 1462 1463 1464def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1465 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1466 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1467 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1468 1469 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1470 1471 1472def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1473 """Remove table refs from columns in when statements.""" 1474 alias = expression.this.args.get("alias") 1475 1476 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1477 return self.dialect.normalize_identifier(identifier).name if identifier else None 1478 1479 targets = {normalize(expression.this.this)} 1480 1481 if alias: 1482 targets.add(normalize(alias.this)) 1483 1484 for when in expression.expressions: 1485 when.transform( 1486 lambda node: ( 1487 exp.column(node.this) 1488 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1489 else node 1490 ), 1491 copy=False, 1492 ) 1493 1494 return self.merge_sql(expression) 1495 1496 1497def build_json_extract_path( 1498 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1499) -> t.Callable[[t.List], F]: 1500 def _builder(args: t.List) -> F: 1501 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1502 for arg in args[1:]: 1503 if not isinstance(arg, exp.Literal): 1504 # We use the fallback parser because we can't really transpile non-literals safely 1505 return expr_type.from_arg_list(args) 1506 1507 text = arg.name 1508 if is_int(text): 1509 index = int(text) 1510 segments.append( 1511 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1512 ) 1513 else: 1514 segments.append(exp.JSONPathKey(this=text)) 1515 1516 # This is done to avoid failing in the expression validator due to the arg count 1517 del args[2:] 1518 return expr_type( 1519 this=seq_get(args, 0), 1520 expression=exp.JSONPath(expressions=segments), 1521 only_json_types=arrow_req_json_type, 1522 ) 1523 1524 return _builder 1525 1526 1527def json_extract_segments( 1528 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1529) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1530 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1531 path = expression.expression 1532 if not isinstance(path, exp.JSONPath): 1533 return rename_func(name)(self, expression) 1534 1535 segments = [] 1536 for segment in path.expressions: 1537 path = self.sql(segment) 1538 if path: 1539 if isinstance(segment, exp.JSONPathPart) and ( 1540 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1541 ): 1542 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1543 1544 segments.append(path) 1545 1546 if op: 1547 return f" {op} ".join([self.sql(expression.this), *segments]) 1548 return self.func(name, expression.this, *segments) 1549 1550 return _json_extract_segments 1551 1552 1553def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1554 if isinstance(expression.this, exp.JSONPathWildcard): 1555 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1556 1557 return expression.name 1558 1559 1560def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1561 cond = expression.expression 1562 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1563 alias = cond.expressions[0] 1564 cond = cond.this 1565 elif isinstance(cond, exp.Predicate): 1566 alias = "_u" 1567 else: 1568 self.unsupported("Unsupported filter condition") 1569 return "" 1570 1571 unnest = exp.Unnest(expressions=[expression.this]) 1572 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1573 return self.sql(exp.Array(expressions=[filtered])) 1574 1575 1576def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1577 return self.func( 1578 "TO_NUMBER", 1579 expression.this, 1580 expression.args.get("format"), 1581 expression.args.get("nlsparam"), 1582 ) 1583 1584 1585def build_default_decimal_type( 1586 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1587) -> t.Callable[[exp.DataType], exp.DataType]: 1588 def _builder(dtype: exp.DataType) -> exp.DataType: 1589 if dtype.expressions or precision is None: 1590 return dtype 1591 1592 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1593 return exp.DataType.build(f"DECIMAL({params})") 1594 1595 return _builder 1596 1597 1598def build_timestamp_from_parts(args: t.List) -> exp.Func: 1599 if len(args) == 2: 1600 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1601 # so we parse this into Anonymous for now instead of introducing complexity 1602 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1603 1604 return exp.TimestampFromParts.from_arg_list(args) 1605 1606 1607def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1608 return self.func(f"SHA{expression.text('length') or '256'}", expression.this) 1609 1610 1611def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1612 start = expression.args.get("start") 1613 end = expression.args.get("end") 1614 step = expression.args.get("step") 1615 1616 if isinstance(start, exp.Cast): 1617 target_type = start.to 1618 elif isinstance(end, exp.Cast): 1619 target_type = end.to 1620 else: 1621 target_type = None 1622 1623 if start and end and target_type and target_type.is_type("date", "timestamp"): 1624 if isinstance(start, exp.Cast) and target_type is start.to: 1625 end = exp.cast(end, target_type) 1626 else: 1627 start = exp.cast(start, target_type) 1628 1629 return self.func("SEQUENCE", start, end, step)
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGLot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
218class Dialect(metaclass=_Dialect): 219 INDEX_OFFSET = 0 220 """The base index offset for arrays.""" 221 222 WEEK_OFFSET = 0 223 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 224 225 UNNEST_COLUMN_ONLY = False 226 """Whether `UNNEST` table aliases are treated as column aliases.""" 227 228 ALIAS_POST_TABLESAMPLE = False 229 """Whether the table alias comes after tablesample.""" 230 231 TABLESAMPLE_SIZE_IS_PERCENT = False 232 """Whether a size in the table sample clause represents percentage.""" 233 234 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 235 """Specifies the strategy according to which identifiers should be normalized.""" 236 237 IDENTIFIERS_CAN_START_WITH_DIGIT = False 238 """Whether an unquoted identifier can start with a digit.""" 239 240 DPIPE_IS_STRING_CONCAT = True 241 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 242 243 STRICT_STRING_CONCAT = False 244 """Whether `CONCAT`'s arguments must be strings.""" 245 246 SUPPORTS_USER_DEFINED_TYPES = True 247 """Whether user-defined data types are supported.""" 248 249 SUPPORTS_SEMI_ANTI_JOIN = True 250 """Whether `SEMI` or `ANTI` joins are supported.""" 251 252 SUPPORTS_COLUMN_JOIN_MARKS = False 253 """Whether the old-style outer join (+) syntax is supported.""" 254 255 COPY_PARAMS_ARE_CSV = True 256 """Separator of COPY statement parameters.""" 257 258 NORMALIZE_FUNCTIONS: bool | str = "upper" 259 """ 260 Determines how function names are going to be normalized. 261 Possible values: 262 "upper" or True: Convert names to uppercase. 263 "lower": Convert names to lowercase. 264 False: Disables function name normalization. 265 """ 266 267 LOG_BASE_FIRST: t.Optional[bool] = True 268 """ 269 Whether the base comes first in the `LOG` function. 270 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 271 """ 272 273 NULL_ORDERING = "nulls_are_small" 274 """ 275 Default `NULL` ordering method to use if not explicitly set. 276 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 277 """ 278 279 TYPED_DIVISION = False 280 """ 281 Whether the behavior of `a / b` depends on the types of `a` and `b`. 282 False means `a / b` is always float division. 283 True means `a / b` is integer division if both `a` and `b` are integers. 284 """ 285 286 SAFE_DIVISION = False 287 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 288 289 CONCAT_COALESCE = False 290 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 291 292 HEX_LOWERCASE = False 293 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 294 295 DATE_FORMAT = "'%Y-%m-%d'" 296 DATEINT_FORMAT = "'%Y%m%d'" 297 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 298 299 TIME_MAPPING: t.Dict[str, str] = {} 300 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 301 302 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 303 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 304 FORMAT_MAPPING: t.Dict[str, str] = {} 305 """ 306 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 307 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 308 """ 309 310 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 311 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 312 313 PSEUDOCOLUMNS: t.Set[str] = set() 314 """ 315 Columns that are auto-generated by the engine corresponding to this dialect. 316 For example, such columns may be excluded from `SELECT *` queries. 317 """ 318 319 PREFER_CTE_ALIAS_COLUMN = False 320 """ 321 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 322 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 323 any projection aliases in the subquery. 324 325 For example, 326 WITH y(c) AS ( 327 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 328 ) SELECT c FROM y; 329 330 will be rewritten as 331 332 WITH y(c) AS ( 333 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 334 ) SELECT c FROM y; 335 """ 336 337 COPY_PARAMS_ARE_CSV = True 338 """ 339 Whether COPY statement parameters are separated by comma or whitespace 340 """ 341 342 FORCE_EARLY_ALIAS_REF_EXPANSION = False 343 """ 344 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 345 346 For example: 347 WITH data AS ( 348 SELECT 349 1 AS id, 350 2 AS my_id 351 ) 352 SELECT 353 id AS my_id 354 FROM 355 data 356 WHERE 357 my_id = 1 358 GROUP BY 359 my_id, 360 HAVING 361 my_id = 1 362 363 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 364 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 365 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 366 """ 367 368 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 369 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 370 371 SUPPORTS_ORDER_BY_ALL = False 372 """ 373 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 374 """ 375 376 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 377 """ 378 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 379 as the former is of type INT[] vs the latter which is SUPER 380 """ 381 382 SUPPORTS_FIXED_SIZE_ARRAYS = False 383 """ 384 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In 385 dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator 386 """ 387 388 # --- Autofilled --- 389 390 tokenizer_class = Tokenizer 391 jsonpath_tokenizer_class = JSONPathTokenizer 392 parser_class = Parser 393 generator_class = Generator 394 395 # A trie of the time_mapping keys 396 TIME_TRIE: t.Dict = {} 397 FORMAT_TRIE: t.Dict = {} 398 399 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 400 INVERSE_TIME_TRIE: t.Dict = {} 401 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 402 INVERSE_FORMAT_TRIE: t.Dict = {} 403 404 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 405 406 # Delimiters for string literals and identifiers 407 QUOTE_START = "'" 408 QUOTE_END = "'" 409 IDENTIFIER_START = '"' 410 IDENTIFIER_END = '"' 411 412 # Delimiters for bit, hex, byte and unicode literals 413 BIT_START: t.Optional[str] = None 414 BIT_END: t.Optional[str] = None 415 HEX_START: t.Optional[str] = None 416 HEX_END: t.Optional[str] = None 417 BYTE_START: t.Optional[str] = None 418 BYTE_END: t.Optional[str] = None 419 UNICODE_START: t.Optional[str] = None 420 UNICODE_END: t.Optional[str] = None 421 422 DATE_PART_MAPPING = { 423 "Y": "YEAR", 424 "YY": "YEAR", 425 "YYY": "YEAR", 426 "YYYY": "YEAR", 427 "YR": "YEAR", 428 "YEARS": "YEAR", 429 "YRS": "YEAR", 430 "MM": "MONTH", 431 "MON": "MONTH", 432 "MONS": "MONTH", 433 "MONTHS": "MONTH", 434 "D": "DAY", 435 "DD": "DAY", 436 "DAYS": "DAY", 437 "DAYOFMONTH": "DAY", 438 "DAY OF WEEK": "DAYOFWEEK", 439 "WEEKDAY": "DAYOFWEEK", 440 "DOW": "DAYOFWEEK", 441 "DW": "DAYOFWEEK", 442 "WEEKDAY_ISO": "DAYOFWEEKISO", 443 "DOW_ISO": "DAYOFWEEKISO", 444 "DW_ISO": "DAYOFWEEKISO", 445 "DAY OF YEAR": "DAYOFYEAR", 446 "DOY": "DAYOFYEAR", 447 "DY": "DAYOFYEAR", 448 "W": "WEEK", 449 "WK": "WEEK", 450 "WEEKOFYEAR": "WEEK", 451 "WOY": "WEEK", 452 "WY": "WEEK", 453 "WEEK_ISO": "WEEKISO", 454 "WEEKOFYEARISO": "WEEKISO", 455 "WEEKOFYEAR_ISO": "WEEKISO", 456 "Q": "QUARTER", 457 "QTR": "QUARTER", 458 "QTRS": "QUARTER", 459 "QUARTERS": "QUARTER", 460 "H": "HOUR", 461 "HH": "HOUR", 462 "HR": "HOUR", 463 "HOURS": "HOUR", 464 "HRS": "HOUR", 465 "M": "MINUTE", 466 "MI": "MINUTE", 467 "MIN": "MINUTE", 468 "MINUTES": "MINUTE", 469 "MINS": "MINUTE", 470 "S": "SECOND", 471 "SEC": "SECOND", 472 "SECONDS": "SECOND", 473 "SECS": "SECOND", 474 "MS": "MILLISECOND", 475 "MSEC": "MILLISECOND", 476 "MSECS": "MILLISECOND", 477 "MSECOND": "MILLISECOND", 478 "MSECONDS": "MILLISECOND", 479 "MILLISEC": "MILLISECOND", 480 "MILLISECS": "MILLISECOND", 481 "MILLISECON": "MILLISECOND", 482 "MILLISECONDS": "MILLISECOND", 483 "US": "MICROSECOND", 484 "USEC": "MICROSECOND", 485 "USECS": "MICROSECOND", 486 "MICROSEC": "MICROSECOND", 487 "MICROSECS": "MICROSECOND", 488 "USECOND": "MICROSECOND", 489 "USECONDS": "MICROSECOND", 490 "MICROSECONDS": "MICROSECOND", 491 "NS": "NANOSECOND", 492 "NSEC": "NANOSECOND", 493 "NANOSEC": "NANOSECOND", 494 "NSECOND": "NANOSECOND", 495 "NSECONDS": "NANOSECOND", 496 "NANOSECS": "NANOSECOND", 497 "EPOCH_SECOND": "EPOCH", 498 "EPOCH_SECONDS": "EPOCH", 499 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 500 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 501 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 502 "TZH": "TIMEZONE_HOUR", 503 "TZM": "TIMEZONE_MINUTE", 504 "DEC": "DECADE", 505 "DECS": "DECADE", 506 "DECADES": "DECADE", 507 "MIL": "MILLENIUM", 508 "MILS": "MILLENIUM", 509 "MILLENIA": "MILLENIUM", 510 "C": "CENTURY", 511 "CENT": "CENTURY", 512 "CENTS": "CENTURY", 513 "CENTURIES": "CENTURY", 514 } 515 516 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 517 exp.DataType.Type.BIGINT: { 518 exp.ApproxDistinct, 519 exp.ArraySize, 520 exp.Count, 521 exp.Length, 522 }, 523 exp.DataType.Type.BOOLEAN: { 524 exp.Between, 525 exp.Boolean, 526 exp.In, 527 exp.RegexpLike, 528 }, 529 exp.DataType.Type.DATE: { 530 exp.CurrentDate, 531 exp.Date, 532 exp.DateFromParts, 533 exp.DateStrToDate, 534 exp.DiToDate, 535 exp.StrToDate, 536 exp.TimeStrToDate, 537 exp.TsOrDsToDate, 538 }, 539 exp.DataType.Type.DATETIME: { 540 exp.CurrentDatetime, 541 exp.Datetime, 542 exp.DatetimeAdd, 543 exp.DatetimeSub, 544 }, 545 exp.DataType.Type.DOUBLE: { 546 exp.ApproxQuantile, 547 exp.Avg, 548 exp.Div, 549 exp.Exp, 550 exp.Ln, 551 exp.Log, 552 exp.Pow, 553 exp.Quantile, 554 exp.Round, 555 exp.SafeDivide, 556 exp.Sqrt, 557 exp.Stddev, 558 exp.StddevPop, 559 exp.StddevSamp, 560 exp.Variance, 561 exp.VariancePop, 562 }, 563 exp.DataType.Type.INT: { 564 exp.Ceil, 565 exp.DatetimeDiff, 566 exp.DateDiff, 567 exp.TimestampDiff, 568 exp.TimeDiff, 569 exp.DateToDi, 570 exp.Levenshtein, 571 exp.Sign, 572 exp.StrPosition, 573 exp.TsOrDiToDi, 574 }, 575 exp.DataType.Type.JSON: { 576 exp.ParseJSON, 577 }, 578 exp.DataType.Type.TIME: { 579 exp.Time, 580 }, 581 exp.DataType.Type.TIMESTAMP: { 582 exp.CurrentTime, 583 exp.CurrentTimestamp, 584 exp.StrToTime, 585 exp.TimeAdd, 586 exp.TimeStrToTime, 587 exp.TimeSub, 588 exp.TimestampAdd, 589 exp.TimestampSub, 590 exp.UnixToTime, 591 }, 592 exp.DataType.Type.TINYINT: { 593 exp.Day, 594 exp.Month, 595 exp.Week, 596 exp.Year, 597 exp.Quarter, 598 }, 599 exp.DataType.Type.VARCHAR: { 600 exp.ArrayConcat, 601 exp.Concat, 602 exp.ConcatWs, 603 exp.DateToDateStr, 604 exp.GroupConcat, 605 exp.Initcap, 606 exp.Lower, 607 exp.Substring, 608 exp.TimeToStr, 609 exp.TimeToTimeStr, 610 exp.Trim, 611 exp.TsOrDsToDateStr, 612 exp.UnixToStr, 613 exp.UnixToTimeStr, 614 exp.Upper, 615 }, 616 } 617 618 ANNOTATORS: AnnotatorsType = { 619 **{ 620 expr_type: lambda self, e: self._annotate_unary(e) 621 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 622 }, 623 **{ 624 expr_type: lambda self, e: self._annotate_binary(e) 625 for expr_type in subclasses(exp.__name__, exp.Binary) 626 }, 627 **{ 628 expr_type: _annotate_with_type_lambda(data_type) 629 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 630 for expr_type in expressions 631 }, 632 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 633 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 634 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 635 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 636 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 637 exp.Bracket: lambda self, e: self._annotate_bracket(e), 638 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 639 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 640 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 641 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 642 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 643 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 644 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 645 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 646 exp.Div: lambda self, e: self._annotate_div(e), 647 exp.Dot: lambda self, e: self._annotate_dot(e), 648 exp.Explode: lambda self, e: self._annotate_explode(e), 649 exp.Extract: lambda self, e: self._annotate_extract(e), 650 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 651 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 652 e, exp.DataType.build("ARRAY<DATE>") 653 ), 654 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 655 e, exp.DataType.build("ARRAY<TIMESTAMP>") 656 ), 657 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 658 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 659 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 660 exp.Literal: lambda self, e: self._annotate_literal(e), 661 exp.Map: lambda self, e: self._annotate_map(e), 662 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 663 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 664 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 665 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 666 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 667 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 668 exp.Struct: lambda self, e: self._annotate_struct(e), 669 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 670 exp.Timestamp: lambda self, e: self._annotate_with_type( 671 e, 672 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 673 ), 674 exp.ToMap: lambda self, e: self._annotate_to_map(e), 675 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 676 exp.Unnest: lambda self, e: self._annotate_unnest(e), 677 exp.VarMap: lambda self, e: self._annotate_map(e), 678 } 679 680 @classmethod 681 def get_or_raise(cls, dialect: DialectType) -> Dialect: 682 """ 683 Look up a dialect in the global dialect registry and return it if it exists. 684 685 Args: 686 dialect: The target dialect. If this is a string, it can be optionally followed by 687 additional key-value pairs that are separated by commas and are used to specify 688 dialect settings, such as whether the dialect's identifiers are case-sensitive. 689 690 Example: 691 >>> dialect = dialect_class = get_or_raise("duckdb") 692 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 693 694 Returns: 695 The corresponding Dialect instance. 696 """ 697 698 if not dialect: 699 return cls() 700 if isinstance(dialect, _Dialect): 701 return dialect() 702 if isinstance(dialect, Dialect): 703 return dialect 704 if isinstance(dialect, str): 705 try: 706 dialect_name, *kv_strings = dialect.split(",") 707 kv_pairs = (kv.split("=") for kv in kv_strings) 708 kwargs = {} 709 for pair in kv_pairs: 710 key = pair[0].strip() 711 value: t.Union[bool | str | None] = None 712 713 if len(pair) == 1: 714 # Default initialize standalone settings to True 715 value = True 716 elif len(pair) == 2: 717 value = pair[1].strip() 718 719 # Coerce the value to boolean if it matches to the truthy/falsy values below 720 value_lower = value.lower() 721 if value_lower in ("true", "1"): 722 value = True 723 elif value_lower in ("false", "0"): 724 value = False 725 726 kwargs[key] = value 727 728 except ValueError: 729 raise ValueError( 730 f"Invalid dialect format: '{dialect}'. " 731 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 732 ) 733 734 result = cls.get(dialect_name.strip()) 735 if not result: 736 from difflib import get_close_matches 737 738 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 739 if similar: 740 similar = f" Did you mean {similar}?" 741 742 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 743 744 return result(**kwargs) 745 746 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 747 748 @classmethod 749 def format_time( 750 cls, expression: t.Optional[str | exp.Expression] 751 ) -> t.Optional[exp.Expression]: 752 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 753 if isinstance(expression, str): 754 return exp.Literal.string( 755 # the time formats are quoted 756 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 757 ) 758 759 if expression and expression.is_string: 760 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 761 762 return expression 763 764 def __init__(self, **kwargs) -> None: 765 normalization_strategy = kwargs.pop("normalization_strategy", None) 766 767 if normalization_strategy is None: 768 self.normalization_strategy = self.NORMALIZATION_STRATEGY 769 else: 770 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 771 772 self.settings = kwargs 773 774 def __eq__(self, other: t.Any) -> bool: 775 # Does not currently take dialect state into account 776 return type(self) == other 777 778 def __hash__(self) -> int: 779 # Does not currently take dialect state into account 780 return hash(type(self)) 781 782 def normalize_identifier(self, expression: E) -> E: 783 """ 784 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 785 786 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 787 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 788 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 789 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 790 791 There are also dialects like Spark, which are case-insensitive even when quotes are 792 present, and dialects like MySQL, whose resolution rules match those employed by the 793 underlying operating system, for example they may always be case-sensitive in Linux. 794 795 Finally, the normalization behavior of some engines can even be controlled through flags, 796 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 797 798 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 799 that it can analyze queries in the optimizer and successfully capture their semantics. 800 """ 801 if ( 802 isinstance(expression, exp.Identifier) 803 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 804 and ( 805 not expression.quoted 806 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 807 ) 808 ): 809 expression.set( 810 "this", 811 ( 812 expression.this.upper() 813 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 814 else expression.this.lower() 815 ), 816 ) 817 818 return expression 819 820 def case_sensitive(self, text: str) -> bool: 821 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 822 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 823 return False 824 825 unsafe = ( 826 str.islower 827 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 828 else str.isupper 829 ) 830 return any(unsafe(char) for char in text) 831 832 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 833 """Checks if text can be identified given an identify option. 834 835 Args: 836 text: The text to check. 837 identify: 838 `"always"` or `True`: Always returns `True`. 839 `"safe"`: Only returns `True` if the identifier is case-insensitive. 840 841 Returns: 842 Whether the given text can be identified. 843 """ 844 if identify is True or identify == "always": 845 return True 846 847 if identify == "safe": 848 return not self.case_sensitive(text) 849 850 return False 851 852 def quote_identifier(self, expression: E, identify: bool = True) -> E: 853 """ 854 Adds quotes to a given identifier. 855 856 Args: 857 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 858 identify: If set to `False`, the quotes will only be added if the identifier is deemed 859 "unsafe", with respect to its characters and this dialect's normalization strategy. 860 """ 861 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 862 name = expression.this 863 expression.set( 864 "quoted", 865 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 866 ) 867 868 return expression 869 870 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 871 if isinstance(path, exp.Literal): 872 path_text = path.name 873 if path.is_number: 874 path_text = f"[{path_text}]" 875 try: 876 return parse_json_path(path_text, self) 877 except ParseError as e: 878 logger.warning(f"Invalid JSON path syntax. {str(e)}") 879 880 return path 881 882 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 883 return self.parser(**opts).parse(self.tokenize(sql), sql) 884 885 def parse_into( 886 self, expression_type: exp.IntoType, sql: str, **opts 887 ) -> t.List[t.Optional[exp.Expression]]: 888 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 889 890 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 891 return self.generator(**opts).generate(expression, copy=copy) 892 893 def transpile(self, sql: str, **opts) -> t.List[str]: 894 return [ 895 self.generate(expression, copy=False, **opts) if expression else "" 896 for expression in self.parse(sql) 897 ] 898 899 def tokenize(self, sql: str) -> t.List[Token]: 900 return self.tokenizer.tokenize(sql) 901 902 @property 903 def tokenizer(self) -> Tokenizer: 904 return self.tokenizer_class(dialect=self) 905 906 @property 907 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 908 return self.jsonpath_tokenizer_class(dialect=self) 909 910 def parser(self, **opts) -> Parser: 911 return self.parser_class(dialect=self, **opts) 912 913 def generator(self, **opts) -> Generator: 914 return self.generator_class(dialect=self, **opts)
764 def __init__(self, **kwargs) -> None: 765 normalization_strategy = kwargs.pop("normalization_strategy", None) 766 767 if normalization_strategy is None: 768 self.normalization_strategy = self.NORMALIZATION_STRATEGY 769 else: 770 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 771 772 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of a / b
depends on the types of a
and b
.
False means a / b
is always float division.
True means a / b
is integer division if both a
and b
are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (\n
) to its unescaped version (
).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS ( SELECT 1 AS id, 2 AS my_id ) SELECT id AS my_id FROM data WHERE my_id = 1 GROUP BY my_id, HAVING my_id = 1
In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) as the former is of type INT[] vs the latter which is SUPER
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator
680 @classmethod 681 def get_or_raise(cls, dialect: DialectType) -> Dialect: 682 """ 683 Look up a dialect in the global dialect registry and return it if it exists. 684 685 Args: 686 dialect: The target dialect. If this is a string, it can be optionally followed by 687 additional key-value pairs that are separated by commas and are used to specify 688 dialect settings, such as whether the dialect's identifiers are case-sensitive. 689 690 Example: 691 >>> dialect = dialect_class = get_or_raise("duckdb") 692 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 693 694 Returns: 695 The corresponding Dialect instance. 696 """ 697 698 if not dialect: 699 return cls() 700 if isinstance(dialect, _Dialect): 701 return dialect() 702 if isinstance(dialect, Dialect): 703 return dialect 704 if isinstance(dialect, str): 705 try: 706 dialect_name, *kv_strings = dialect.split(",") 707 kv_pairs = (kv.split("=") for kv in kv_strings) 708 kwargs = {} 709 for pair in kv_pairs: 710 key = pair[0].strip() 711 value: t.Union[bool | str | None] = None 712 713 if len(pair) == 1: 714 # Default initialize standalone settings to True 715 value = True 716 elif len(pair) == 2: 717 value = pair[1].strip() 718 719 # Coerce the value to boolean if it matches to the truthy/falsy values below 720 value_lower = value.lower() 721 if value_lower in ("true", "1"): 722 value = True 723 elif value_lower in ("false", "0"): 724 value = False 725 726 kwargs[key] = value 727 728 except ValueError: 729 raise ValueError( 730 f"Invalid dialect format: '{dialect}'. " 731 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 732 ) 733 734 result = cls.get(dialect_name.strip()) 735 if not result: 736 from difflib import get_close_matches 737 738 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 739 if similar: 740 similar = f" Did you mean {similar}?" 741 742 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 743 744 return result(**kwargs) 745 746 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
748 @classmethod 749 def format_time( 750 cls, expression: t.Optional[str | exp.Expression] 751 ) -> t.Optional[exp.Expression]: 752 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 753 if isinstance(expression, str): 754 return exp.Literal.string( 755 # the time formats are quoted 756 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 757 ) 758 759 if expression and expression.is_string: 760 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 761 762 return expression
Converts a time format in this dialect to its equivalent Python strftime
format.
782 def normalize_identifier(self, expression: E) -> E: 783 """ 784 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 785 786 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 787 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 788 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 789 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 790 791 There are also dialects like Spark, which are case-insensitive even when quotes are 792 present, and dialects like MySQL, whose resolution rules match those employed by the 793 underlying operating system, for example they may always be case-sensitive in Linux. 794 795 Finally, the normalization behavior of some engines can even be controlled through flags, 796 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 797 798 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 799 that it can analyze queries in the optimizer and successfully capture their semantics. 800 """ 801 if ( 802 isinstance(expression, exp.Identifier) 803 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 804 and ( 805 not expression.quoted 806 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 807 ) 808 ): 809 expression.set( 810 "this", 811 ( 812 expression.this.upper() 813 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 814 else expression.this.lower() 815 ), 816 ) 817 818 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO
would be resolved as foo
in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO
. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
820 def case_sensitive(self, text: str) -> bool: 821 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 822 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 823 return False 824 825 unsafe = ( 826 str.islower 827 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 828 else str.isupper 829 ) 830 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
832 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 833 """Checks if text can be identified given an identify option. 834 835 Args: 836 text: The text to check. 837 identify: 838 `"always"` or `True`: Always returns `True`. 839 `"safe"`: Only returns `True` if the identifier is case-insensitive. 840 841 Returns: 842 Whether the given text can be identified. 843 """ 844 if identify is True or identify == "always": 845 return True 846 847 if identify == "safe": 848 return not self.case_sensitive(text) 849 850 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
"always"
orTrue
: Always returnsTrue
."safe"
: Only returnsTrue
if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
852 def quote_identifier(self, expression: E, identify: bool = True) -> E: 853 """ 854 Adds quotes to a given identifier. 855 856 Args: 857 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 858 identify: If set to `False`, the quotes will only be added if the identifier is deemed 859 "unsafe", with respect to its characters and this dialect's normalization strategy. 860 """ 861 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 862 name = expression.this 863 expression.set( 864 "quoted", 865 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 866 ) 867 868 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an
Identifier
, this method is a no-op. - identify: If set to
False
, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
870 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 871 if isinstance(path, exp.Literal): 872 path_text = path.name 873 if path.is_number: 874 path_text = f"[{path_text}]" 875 try: 876 return parse_json_path(path_text, self) 877 except ParseError as e: 878 logger.warning(f"Invalid JSON path syntax. {str(e)}") 879 880 return path
930def if_sql( 931 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 932) -> t.Callable[[Generator, exp.If], str]: 933 def _if_sql(self: Generator, expression: exp.If) -> str: 934 return self.func( 935 name, 936 expression.this, 937 expression.args.get("true"), 938 expression.args.get("false") or false_value, 939 ) 940 941 return _if_sql
944def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 945 this = expression.this 946 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 947 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 948 949 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1015def str_position_sql( 1016 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1017) -> str: 1018 this = self.sql(expression, "this") 1019 substr = self.sql(expression, "substr") 1020 position = self.sql(expression, "position") 1021 instance = expression.args.get("instance") if generate_instance else None 1022 position_offset = "" 1023 1024 if position: 1025 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1026 this = self.func("SUBSTR", this, position) 1027 position_offset = f" + {position} - 1" 1028 1029 return self.func("STRPOS", this, substr, instance) + position_offset
1038def var_map_sql( 1039 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1040) -> str: 1041 keys = expression.args["keys"] 1042 values = expression.args["values"] 1043 1044 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1045 self.unsupported("Cannot convert array columns into map.") 1046 return self.func(map_func_name, keys, values) 1047 1048 args = [] 1049 for key, value in zip(keys.expressions, values.expressions): 1050 args.append(self.sql(key)) 1051 args.append(self.sql(value)) 1052 1053 return self.func(map_func_name, *args)
1056def build_formatted_time( 1057 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1058) -> t.Callable[[t.List], E]: 1059 """Helper used for time expressions. 1060 1061 Args: 1062 exp_class: the expression class to instantiate. 1063 dialect: target sql dialect. 1064 default: the default format, True being time. 1065 1066 Returns: 1067 A callable that can be used to return the appropriately formatted time expression. 1068 """ 1069 1070 def _builder(args: t.List): 1071 return exp_class( 1072 this=seq_get(args, 0), 1073 format=Dialect[dialect].format_time( 1074 seq_get(args, 1) 1075 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1076 ), 1077 ) 1078 1079 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
1082def time_format( 1083 dialect: DialectType = None, 1084) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1085 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1086 """ 1087 Returns the time format for a given expression, unless it's equivalent 1088 to the default time format of the dialect of interest. 1089 """ 1090 time_format = self.format_time(expression) 1091 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1092 1093 return _time_format
1096def build_date_delta( 1097 exp_class: t.Type[E], 1098 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1099 default_unit: t.Optional[str] = "DAY", 1100) -> t.Callable[[t.List], E]: 1101 def _builder(args: t.List) -> E: 1102 unit_based = len(args) == 3 1103 this = args[2] if unit_based else seq_get(args, 0) 1104 unit = None 1105 if unit_based or default_unit: 1106 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1107 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1108 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1109 1110 return _builder
1113def build_date_delta_with_interval( 1114 expression_class: t.Type[E], 1115) -> t.Callable[[t.List], t.Optional[E]]: 1116 def _builder(args: t.List) -> t.Optional[E]: 1117 if len(args) < 2: 1118 return None 1119 1120 interval = args[1] 1121 1122 if not isinstance(interval, exp.Interval): 1123 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1124 1125 expression = interval.this 1126 if expression and expression.is_string: 1127 expression = exp.Literal.number(expression.this) 1128 1129 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 1130 1131 return _builder
1134def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1135 unit = seq_get(args, 0) 1136 this = seq_get(args, 1) 1137 1138 if isinstance(this, exp.Cast) and this.is_type("date"): 1139 return exp.DateTrunc(unit=unit, this=this) 1140 return exp.TimestampTrunc(this=this, unit=unit)
1143def date_add_interval_sql( 1144 data_type: str, kind: str 1145) -> t.Callable[[Generator, exp.Expression], str]: 1146 def func(self: Generator, expression: exp.Expression) -> str: 1147 this = self.sql(expression, "this") 1148 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1149 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1150 1151 return func
1154def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1155 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1156 args = [unit_to_str(expression), expression.this] 1157 if zone: 1158 args.append(expression.args.get("zone")) 1159 return self.func("DATE_TRUNC", *args) 1160 1161 return _timestamptrunc_sql
1164def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1165 zone = expression.args.get("zone") 1166 if not zone: 1167 from sqlglot.optimizer.annotate_types import annotate_types 1168 1169 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1170 return self.sql(exp.cast(expression.this, target_type)) 1171 if zone.name.lower() in TIMEZONES: 1172 return self.sql( 1173 exp.AtTimeZone( 1174 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1175 zone=zone, 1176 ) 1177 ) 1178 return self.func("TIMESTAMP", expression.this, zone)
1181def no_time_sql(self: Generator, expression: exp.Time) -> str: 1182 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1183 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1184 expr = exp.cast( 1185 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1186 ) 1187 return self.sql(expr)
1190def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1191 this = expression.this 1192 expr = expression.expression 1193 1194 if expr.name.lower() in TIMEZONES: 1195 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1196 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1197 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1198 return self.sql(this) 1199 1200 this = exp.cast(this, exp.DataType.Type.DATE) 1201 expr = exp.cast(expr, exp.DataType.Type.TIME) 1202 1203 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1244def encode_decode_sql( 1245 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1246) -> str: 1247 charset = expression.args.get("charset") 1248 if charset and charset.name.lower() != "utf-8": 1249 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1250 1251 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1264def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1265 cond = expression.this 1266 1267 if isinstance(expression.this, exp.Distinct): 1268 cond = expression.this.expressions[0] 1269 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1270 1271 return self.func("sum", exp.func("if", cond, 1, 0))
1274def trim_sql(self: Generator, expression: exp.Trim) -> str: 1275 target = self.sql(expression, "this") 1276 trim_type = self.sql(expression, "position") 1277 remove_chars = self.sql(expression, "expression") 1278 collation = self.sql(expression, "collation") 1279 1280 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1281 if not remove_chars and not collation: 1282 return self.trim_sql(expression) 1283 1284 trim_type = f"{trim_type} " if trim_type else "" 1285 remove_chars = f"{remove_chars} " if remove_chars else "" 1286 from_part = "FROM " if trim_type or remove_chars else "" 1287 collation = f" COLLATE {collation}" if collation else "" 1288 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1309def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1310 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1311 if bad_args: 1312 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1313 1314 return self.func( 1315 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1316 )
1319def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1320 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1321 if bad_args: 1322 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1323 1324 return self.func( 1325 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1326 )
1329def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1330 names = [] 1331 for agg in aggregations: 1332 if isinstance(agg, exp.Alias): 1333 names.append(agg.alias) 1334 else: 1335 """ 1336 This case corresponds to aggregations without aliases being used as suffixes 1337 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1338 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1339 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1340 """ 1341 agg_all_unquoted = agg.transform( 1342 lambda node: ( 1343 exp.Identifier(this=node.name, quoted=False) 1344 if isinstance(node, exp.Identifier) 1345 else node 1346 ) 1347 ) 1348 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1349 1350 return names
1390def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1391 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1392 if expression.args.get("count"): 1393 self.unsupported(f"Only two arguments are supported in function {name}.") 1394 1395 return self.func(name, expression.this, expression.expression) 1396 1397 return _arg_max_or_min_sql
1400def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1401 this = expression.this.copy() 1402 1403 return_type = expression.return_type 1404 if return_type.is_type(exp.DataType.Type.DATE): 1405 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1406 # can truncate timestamp strings, because some dialects can't cast them to DATE 1407 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1408 1409 expression.this.replace(exp.cast(this, return_type)) 1410 return expression
1413def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1414 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1415 if cast and isinstance(expression, exp.TsOrDsAdd): 1416 expression = ts_or_ds_add_cast(expression) 1417 1418 return self.func( 1419 name, 1420 unit_to_var(expression), 1421 expression.expression, 1422 expression.this, 1423 ) 1424 1425 return _delta_sql
1428def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1429 unit = expression.args.get("unit") 1430 1431 if isinstance(unit, exp.Placeholder): 1432 return unit 1433 if unit: 1434 return exp.Literal.string(unit.name) 1435 return exp.Literal.string(default) if default else None
1465def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1466 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1467 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1468 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1469 1470 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1473def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1474 """Remove table refs from columns in when statements.""" 1475 alias = expression.this.args.get("alias") 1476 1477 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1478 return self.dialect.normalize_identifier(identifier).name if identifier else None 1479 1480 targets = {normalize(expression.this.this)} 1481 1482 if alias: 1483 targets.add(normalize(alias.this)) 1484 1485 for when in expression.expressions: 1486 when.transform( 1487 lambda node: ( 1488 exp.column(node.this) 1489 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1490 else node 1491 ), 1492 copy=False, 1493 ) 1494 1495 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1498def build_json_extract_path( 1499 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1500) -> t.Callable[[t.List], F]: 1501 def _builder(args: t.List) -> F: 1502 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1503 for arg in args[1:]: 1504 if not isinstance(arg, exp.Literal): 1505 # We use the fallback parser because we can't really transpile non-literals safely 1506 return expr_type.from_arg_list(args) 1507 1508 text = arg.name 1509 if is_int(text): 1510 index = int(text) 1511 segments.append( 1512 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1513 ) 1514 else: 1515 segments.append(exp.JSONPathKey(this=text)) 1516 1517 # This is done to avoid failing in the expression validator due to the arg count 1518 del args[2:] 1519 return expr_type( 1520 this=seq_get(args, 0), 1521 expression=exp.JSONPath(expressions=segments), 1522 only_json_types=arrow_req_json_type, 1523 ) 1524 1525 return _builder
1528def json_extract_segments( 1529 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1530) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1531 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1532 path = expression.expression 1533 if not isinstance(path, exp.JSONPath): 1534 return rename_func(name)(self, expression) 1535 1536 segments = [] 1537 for segment in path.expressions: 1538 path = self.sql(segment) 1539 if path: 1540 if isinstance(segment, exp.JSONPathPart) and ( 1541 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1542 ): 1543 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1544 1545 segments.append(path) 1546 1547 if op: 1548 return f" {op} ".join([self.sql(expression.this), *segments]) 1549 return self.func(name, expression.this, *segments) 1550 1551 return _json_extract_segments
1561def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1562 cond = expression.expression 1563 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1564 alias = cond.expressions[0] 1565 cond = cond.this 1566 elif isinstance(cond, exp.Predicate): 1567 alias = "_u" 1568 else: 1569 self.unsupported("Unsupported filter condition") 1570 return "" 1571 1572 unnest = exp.Unnest(expressions=[expression.this]) 1573 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1574 return self.sql(exp.Array(expressions=[filtered]))
1586def build_default_decimal_type( 1587 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1588) -> t.Callable[[exp.DataType], exp.DataType]: 1589 def _builder(dtype: exp.DataType) -> exp.DataType: 1590 if dtype.expressions or precision is None: 1591 return dtype 1592 1593 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1594 return exp.DataType.build(f"DECIMAL({params})") 1595 1596 return _builder
1599def build_timestamp_from_parts(args: t.List) -> exp.Func: 1600 if len(args) == 2: 1601 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1602 # so we parse this into Anonymous for now instead of introducing complexity 1603 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1604 1605 return exp.TimestampFromParts.from_arg_list(args)
1612def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1613 start = expression.args.get("start") 1614 end = expression.args.get("end") 1615 step = expression.args.get("step") 1616 1617 if isinstance(start, exp.Cast): 1618 target_type = start.to 1619 elif isinstance(end, exp.Cast): 1620 target_type = end.to 1621 else: 1622 target_type = None 1623 1624 if start and end and target_type and target_type.is_type("date", "timestamp"): 1625 if isinstance(start, exp.Cast) and target_type is start.to: 1626 end = exp.cast(end, target_type) 1627 else: 1628 start = exp.cast(start, target_type) 1629 1630 return self.func("SEQUENCE", start, end, step)