sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Whether COPY statement parameters are separated by comma or whitespace."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.

    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """
    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before
    column qualification (_qualify_columns()).

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" across the query (which is
    done in _qualify_columns()), except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it
          resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query, i.e. it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns), as in DuckDB
    and Spark3/Databricks.
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3),
    as the former is of type INT[] while the latter is of type SUPER.
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g. in DuckDB.
    In dialects which don't support fixed size arrays, such as Snowflake, this should be
    interpreted as a subscript/index operator.
    """
    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Div,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = dialect_class = get_or_raise("duckdb")
            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Union[bool, str, None] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                        # Coerce the value to boolean if it matches the truthy/falsy values below
                        value_lower = value.lower()
                        if value_lower in ("true", "1"):
                            value = True
                        elif value_lower in ("false", "0"):
                            value = False

                    kwargs[key] = value

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = None
        if unit_based or default_unit:
            unit = args[0] if unit_based else exp.Literal.string(default_unit)
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


def sequence_sql(self: Generator, expression: exp.GenerateSeries):
    start = expression.args["start"]
    end = expression.args["end"]
    step = expression.args.get("step")

    if isinstance(start, exp.Cast):
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if target_type and target_type.is_type("timestamp"):
        if target_type is start.to:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)
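
The sketches below are editorial usage examples, not part of the module source. First, dialect lookup: Dialect.get_or_raise resolves names registered by the _Dialect metaclass, and a settings suffix in the string is parsed into constructor kwargs (with "true"/"1" and "false"/"0" coerced to booleans).

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    assert dialect.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

    # _Dialect.__eq__ also lets a dialect class compare equal to its registry name
    assert Dialect["duckdb"] == "duckdb"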
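
normalize_identifier in action: unquoted identifiers follow the dialect's normalization strategy, while quoted ones are preserved under the LOWERCASE/UPPERCASE strategies. A sketch (exp.to_identifier builds the identifier node):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted, so subject to normalization
    assert Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name == "foo"
    assert Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name == "FOO"

    quoted = exp.to_identifier("FoO", quoted=True)  # quoted: left untouched
    assert Dialect.get_or_raise("postgres").normalize_identifier(quoted.copy()).name == "FoO"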
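
The parse, generate and transpile methods wire the dialect-bound tokenizer, parser and generator together; a round trip looks like this:

    from sqlglot.dialects.dialect import Dialect

    duckdb = Dialect.get_or_raise("duckdb")
    tree = duckdb.parse("SELECT 1 AS x")[0]   # tokenize + parse into an expression tree
    sql = duckdb.generate(tree)               # render the tree back to SQL text
    assert duckdb.transpile("SELECT 1 AS x") == [sql]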
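
Because _Dialect registers every subclass, defining a custom dialect is just subclassing, and helpers such as rename_func plug into the generator's TRANSFORMS mapping. A minimal sketch (the class name Custom is illustrative):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect, rename_func
    from sqlglot.generator import Generator as BaseGenerator

    class Custom(Dialect):  # auto-registered under "custom"
        class Generator(BaseGenerator):
            TRANSFORMS = {
                **BaseGenerator.TRANSFORMS,
                exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            }

    sql = Dialect.get_or_raise("custom").generate(
        exp.select(exp.ApproxDistinct(this=exp.column("a"))).from_("t")
    )
    # expected to render roughly as: SELECT APPROX_COUNT_DISTINCT(a) FROM t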
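
build_formatted_time is how parsers typically map format-string functions: the returned builder converts the SQL-level format into a Python strftime format through the target dialect's TIME_MAPPING. A sketch against the Hive mapping, assuming yyyy-MM-dd corresponds to %Y-%m-%d there:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    builder = build_formatted_time(exp.StrToTime, "hive")
    node = builder([exp.column("dt"), exp.Literal.string("yyyy-MM-dd")])
    print(node.args["format"])  # '%Y-%m-%d'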
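
Finally, DATE_PART_MAPPING normalizes the many date-part aliases, and map_date_part applies it through a dialect (assuming the base "DD" -> "DAY" entry is not overridden by the chosen dialect):

    from sqlglot import exp
    from sqlglot.dialects.dialect import map_date_part

    print(map_date_part(exp.var("dd"), "snowflake"))  # DAY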
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGLot.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
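As a rough illustration, the effective strategy is exposed on Dialect instances and can be overridden at construction time (per-dialect defaults may differ across sqlglot versions):

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Read a dialect's effective strategy.
print(Dialect.get_or_raise("postgres").normalization_strategy)

# Override the default via a settings string (see Dialect.get_or_raise below).
d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE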
Inherited Members
- enum.Enum: name, value
- builtins.str: the standard string methods (encode, split, lower, upper, format, ...)
215class Dialect(metaclass=_Dialect): 216 INDEX_OFFSET = 0 217 """The base index offset for arrays.""" 218 219 WEEK_OFFSET = 0 220 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 221 222 UNNEST_COLUMN_ONLY = False 223 """Whether `UNNEST` table aliases are treated as column aliases.""" 224 225 ALIAS_POST_TABLESAMPLE = False 226 """Whether the table alias comes after tablesample.""" 227 228 TABLESAMPLE_SIZE_IS_PERCENT = False 229 """Whether a size in the table sample clause represents percentage.""" 230 231 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 232 """Specifies the strategy according to which identifiers should be normalized.""" 233 234 IDENTIFIERS_CAN_START_WITH_DIGIT = False 235 """Whether an unquoted identifier can start with a digit.""" 236 237 DPIPE_IS_STRING_CONCAT = True 238 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 239 240 STRICT_STRING_CONCAT = False 241 """Whether `CONCAT`'s arguments must be strings.""" 242 243 SUPPORTS_USER_DEFINED_TYPES = True 244 """Whether user-defined data types are supported.""" 245 246 SUPPORTS_SEMI_ANTI_JOIN = True 247 """Whether `SEMI` or `ANTI` joins are supported.""" 248 249 SUPPORTS_COLUMN_JOIN_MARKS = False 250 """Whether the old-style outer join (+) syntax is supported.""" 251 252 COPY_PARAMS_ARE_CSV = True 253 """Separator of COPY statement parameters.""" 254 255 NORMALIZE_FUNCTIONS: bool | str = "upper" 256 """ 257 Determines how function names are going to be normalized. 258 Possible values: 259 "upper" or True: Convert names to uppercase. 260 "lower": Convert names to lowercase. 261 False: Disables function name normalization. 262 """ 263 264 LOG_BASE_FIRST: t.Optional[bool] = True 265 """ 266 Whether the base comes first in the `LOG` function. 267 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 268 """ 269 270 NULL_ORDERING = "nulls_are_small" 271 """ 272 Default `NULL` ordering method to use if not explicitly set. 273 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 274 """ 275 276 TYPED_DIVISION = False 277 """ 278 Whether the behavior of `a / b` depends on the types of `a` and `b`. 279 False means `a / b` is always float division. 280 True means `a / b` is integer division if both `a` and `b` are integers. 281 """ 282 283 SAFE_DIVISION = False 284 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 285 286 CONCAT_COALESCE = False 287 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 288 289 HEX_LOWERCASE = False 290 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 291 292 DATE_FORMAT = "'%Y-%m-%d'" 293 DATEINT_FORMAT = "'%Y%m%d'" 294 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 295 296 TIME_MAPPING: t.Dict[str, str] = {} 297 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 298 299 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 300 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 301 FORMAT_MAPPING: t.Dict[str, str] = {} 302 """ 303 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 304 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
305 """ 306 307 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 308 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 309 310 PSEUDOCOLUMNS: t.Set[str] = set() 311 """ 312 Columns that are auto-generated by the engine corresponding to this dialect. 313 For example, such columns may be excluded from `SELECT *` queries. 314 """ 315 316 PREFER_CTE_ALIAS_COLUMN = False 317 """ 318 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 319 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 320 any projection aliases in the subquery. 321 322 For example, 323 WITH y(c) AS ( 324 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 325 ) SELECT c FROM y; 326 327 will be rewritten as 328 329 WITH y(c) AS ( 330 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 331 ) SELECT c FROM y; 332 """ 333 334 COPY_PARAMS_ARE_CSV = True 335 """ 336 Whether COPY statement parameters are separated by comma or whitespace 337 """ 338 339 FORCE_EARLY_ALIAS_REF_EXPANSION = False 340 """ 341 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 342 343 For example: 344 WITH data AS ( 345 SELECT 346 1 AS id, 347 2 AS my_id 348 ) 349 SELECT 350 id AS my_id 351 FROM 352 data 353 WHERE 354 my_id = 1 355 GROUP BY 356 my_id, 357 HAVING 358 my_id = 1 359 360 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 361 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 362 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 363 """ 364 365 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 366 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 367 368 SUPPORTS_ORDER_BY_ALL = False 369 """ 370 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 371 """ 372 373 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 374 """ 375 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 376 as the former is of type INT[] vs the latter which is SUPER 377 """ 378 379 SUPPORTS_FIXED_SIZE_ARRAYS = False 380 """ 381 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. 
In 382 dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator 383 """ 384 385 # --- Autofilled --- 386 387 tokenizer_class = Tokenizer 388 jsonpath_tokenizer_class = JSONPathTokenizer 389 parser_class = Parser 390 generator_class = Generator 391 392 # A trie of the time_mapping keys 393 TIME_TRIE: t.Dict = {} 394 FORMAT_TRIE: t.Dict = {} 395 396 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 397 INVERSE_TIME_TRIE: t.Dict = {} 398 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 399 INVERSE_FORMAT_TRIE: t.Dict = {} 400 401 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 402 403 # Delimiters for string literals and identifiers 404 QUOTE_START = "'" 405 QUOTE_END = "'" 406 IDENTIFIER_START = '"' 407 IDENTIFIER_END = '"' 408 409 # Delimiters for bit, hex, byte and unicode literals 410 BIT_START: t.Optional[str] = None 411 BIT_END: t.Optional[str] = None 412 HEX_START: t.Optional[str] = None 413 HEX_END: t.Optional[str] = None 414 BYTE_START: t.Optional[str] = None 415 BYTE_END: t.Optional[str] = None 416 UNICODE_START: t.Optional[str] = None 417 UNICODE_END: t.Optional[str] = None 418 419 DATE_PART_MAPPING = { 420 "Y": "YEAR", 421 "YY": "YEAR", 422 "YYY": "YEAR", 423 "YYYY": "YEAR", 424 "YR": "YEAR", 425 "YEARS": "YEAR", 426 "YRS": "YEAR", 427 "MM": "MONTH", 428 "MON": "MONTH", 429 "MONS": "MONTH", 430 "MONTHS": "MONTH", 431 "D": "DAY", 432 "DD": "DAY", 433 "DAYS": "DAY", 434 "DAYOFMONTH": "DAY", 435 "DAY OF WEEK": "DAYOFWEEK", 436 "WEEKDAY": "DAYOFWEEK", 437 "DOW": "DAYOFWEEK", 438 "DW": "DAYOFWEEK", 439 "WEEKDAY_ISO": "DAYOFWEEKISO", 440 "DOW_ISO": "DAYOFWEEKISO", 441 "DW_ISO": "DAYOFWEEKISO", 442 "DAY OF YEAR": "DAYOFYEAR", 443 "DOY": "DAYOFYEAR", 444 "DY": "DAYOFYEAR", 445 "W": "WEEK", 446 "WK": "WEEK", 447 "WEEKOFYEAR": "WEEK", 448 "WOY": "WEEK", 449 "WY": "WEEK", 450 "WEEK_ISO": "WEEKISO", 451 "WEEKOFYEARISO": "WEEKISO", 452 "WEEKOFYEAR_ISO": "WEEKISO", 453 "Q": "QUARTER", 454 "QTR": "QUARTER", 455 "QTRS": "QUARTER", 456 "QUARTERS": "QUARTER", 457 "H": "HOUR", 458 "HH": "HOUR", 459 "HR": "HOUR", 460 "HOURS": "HOUR", 461 "HRS": "HOUR", 462 "M": "MINUTE", 463 "MI": "MINUTE", 464 "MIN": "MINUTE", 465 "MINUTES": "MINUTE", 466 "MINS": "MINUTE", 467 "S": "SECOND", 468 "SEC": "SECOND", 469 "SECONDS": "SECOND", 470 "SECS": "SECOND", 471 "MS": "MILLISECOND", 472 "MSEC": "MILLISECOND", 473 "MSECS": "MILLISECOND", 474 "MSECOND": "MILLISECOND", 475 "MSECONDS": "MILLISECOND", 476 "MILLISEC": "MILLISECOND", 477 "MILLISECS": "MILLISECOND", 478 "MILLISECON": "MILLISECOND", 479 "MILLISECONDS": "MILLISECOND", 480 "US": "MICROSECOND", 481 "USEC": "MICROSECOND", 482 "USECS": "MICROSECOND", 483 "MICROSEC": "MICROSECOND", 484 "MICROSECS": "MICROSECOND", 485 "USECOND": "MICROSECOND", 486 "USECONDS": "MICROSECOND", 487 "MICROSECONDS": "MICROSECOND", 488 "NS": "NANOSECOND", 489 "NSEC": "NANOSECOND", 490 "NANOSEC": "NANOSECOND", 491 "NSECOND": "NANOSECOND", 492 "NSECONDS": "NANOSECOND", 493 "NANOSECS": "NANOSECOND", 494 "EPOCH_SECOND": "EPOCH", 495 "EPOCH_SECONDS": "EPOCH", 496 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 497 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 498 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 499 "TZH": "TIMEZONE_HOUR", 500 "TZM": "TIMEZONE_MINUTE", 501 "DEC": "DECADE", 502 "DECS": "DECADE", 503 "DECADES": "DECADE", 504 "MIL": "MILLENIUM", 505 "MILS": "MILLENIUM", 506 "MILLENIA": "MILLENIUM", 507 "C": "CENTURY", 508 "CENT": "CENTURY", 509 "CENTS": "CENTURY", 510 "CENTURIES": "CENTURY", 511 } 512 513 TYPE_TO_EXPRESSIONS: 
t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 514 exp.DataType.Type.BIGINT: { 515 exp.ApproxDistinct, 516 exp.ArraySize, 517 exp.Count, 518 exp.Length, 519 }, 520 exp.DataType.Type.BOOLEAN: { 521 exp.Between, 522 exp.Boolean, 523 exp.In, 524 exp.RegexpLike, 525 }, 526 exp.DataType.Type.DATE: { 527 exp.CurrentDate, 528 exp.Date, 529 exp.DateFromParts, 530 exp.DateStrToDate, 531 exp.DiToDate, 532 exp.StrToDate, 533 exp.TimeStrToDate, 534 exp.TsOrDsToDate, 535 }, 536 exp.DataType.Type.DATETIME: { 537 exp.CurrentDatetime, 538 exp.Datetime, 539 exp.DatetimeAdd, 540 exp.DatetimeSub, 541 }, 542 exp.DataType.Type.DOUBLE: { 543 exp.ApproxQuantile, 544 exp.Avg, 545 exp.Div, 546 exp.Exp, 547 exp.Ln, 548 exp.Log, 549 exp.Pow, 550 exp.Quantile, 551 exp.Round, 552 exp.SafeDivide, 553 exp.Sqrt, 554 exp.Stddev, 555 exp.StddevPop, 556 exp.StddevSamp, 557 exp.Variance, 558 exp.VariancePop, 559 }, 560 exp.DataType.Type.INT: { 561 exp.Ceil, 562 exp.DatetimeDiff, 563 exp.DateDiff, 564 exp.TimestampDiff, 565 exp.TimeDiff, 566 exp.DateToDi, 567 exp.Levenshtein, 568 exp.Sign, 569 exp.StrPosition, 570 exp.TsOrDiToDi, 571 }, 572 exp.DataType.Type.JSON: { 573 exp.ParseJSON, 574 }, 575 exp.DataType.Type.TIME: { 576 exp.Time, 577 }, 578 exp.DataType.Type.TIMESTAMP: { 579 exp.CurrentTime, 580 exp.CurrentTimestamp, 581 exp.StrToTime, 582 exp.TimeAdd, 583 exp.TimeStrToTime, 584 exp.TimeSub, 585 exp.TimestampAdd, 586 exp.TimestampSub, 587 exp.UnixToTime, 588 }, 589 exp.DataType.Type.TINYINT: { 590 exp.Day, 591 exp.Month, 592 exp.Week, 593 exp.Year, 594 exp.Quarter, 595 }, 596 exp.DataType.Type.VARCHAR: { 597 exp.ArrayConcat, 598 exp.Concat, 599 exp.ConcatWs, 600 exp.DateToDateStr, 601 exp.GroupConcat, 602 exp.Initcap, 603 exp.Lower, 604 exp.Substring, 605 exp.TimeToStr, 606 exp.TimeToTimeStr, 607 exp.Trim, 608 exp.TsOrDsToDateStr, 609 exp.UnixToStr, 610 exp.UnixToTimeStr, 611 exp.Upper, 612 }, 613 } 614 615 ANNOTATORS: AnnotatorsType = { 616 **{ 617 expr_type: lambda self, e: self._annotate_unary(e) 618 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 619 }, 620 **{ 621 expr_type: lambda self, e: self._annotate_binary(e) 622 for expr_type in subclasses(exp.__name__, exp.Binary) 623 }, 624 **{ 625 expr_type: _annotate_with_type_lambda(data_type) 626 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 627 for expr_type in expressions 628 }, 629 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 630 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 631 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 632 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 633 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 634 exp.Bracket: lambda self, e: self._annotate_bracket(e), 635 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 636 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 637 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 638 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 639 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 640 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 641 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 642 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 643 exp.Div: lambda self, e: self._annotate_div(e), 644 exp.Dot: lambda self, e: self._annotate_dot(e), 645 exp.Explode: lambda self, 
e: self._annotate_explode(e), 646 exp.Extract: lambda self, e: self._annotate_extract(e), 647 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 648 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 649 e, exp.DataType.build("ARRAY<DATE>") 650 ), 651 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 652 e, exp.DataType.build("ARRAY<TIMESTAMP>") 653 ), 654 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 655 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 656 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 657 exp.Literal: lambda self, e: self._annotate_literal(e), 658 exp.Map: lambda self, e: self._annotate_map(e), 659 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 660 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 661 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 662 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 663 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 664 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 665 exp.Struct: lambda self, e: self._annotate_struct(e), 666 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 667 exp.Timestamp: lambda self, e: self._annotate_with_type( 668 e, 669 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 670 ), 671 exp.ToMap: lambda self, e: self._annotate_to_map(e), 672 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 673 exp.Unnest: lambda self, e: self._annotate_unnest(e), 674 exp.VarMap: lambda self, e: self._annotate_map(e), 675 } 676 677 @classmethod 678 def get_or_raise(cls, dialect: DialectType) -> Dialect: 679 """ 680 Look up a dialect in the global dialect registry and return it if it exists. 681 682 Args: 683 dialect: The target dialect. If this is a string, it can be optionally followed by 684 additional key-value pairs that are separated by commas and are used to specify 685 dialect settings, such as whether the dialect's identifiers are case-sensitive. 686 687 Example: 688 >>> dialect = dialect_class = get_or_raise("duckdb") 689 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 690 691 Returns: 692 The corresponding Dialect instance. 693 """ 694 695 if not dialect: 696 return cls() 697 if isinstance(dialect, _Dialect): 698 return dialect() 699 if isinstance(dialect, Dialect): 700 return dialect 701 if isinstance(dialect, str): 702 try: 703 dialect_name, *kv_strings = dialect.split(",") 704 kv_pairs = (kv.split("=") for kv in kv_strings) 705 kwargs = {} 706 for pair in kv_pairs: 707 key = pair[0].strip() 708 value: t.Union[bool | str | None] = None 709 710 if len(pair) == 1: 711 # Default initialize standalone settings to True 712 value = True 713 elif len(pair) == 2: 714 value = pair[1].strip() 715 716 # Coerce the value to boolean if it matches to the truthy/falsy values below 717 value_lower = value.lower() 718 if value_lower in ("true", "1"): 719 value = True 720 elif value_lower in ("false", "0"): 721 value = False 722 723 kwargs[key] = value 724 725 except ValueError: 726 raise ValueError( 727 f"Invalid dialect format: '{dialect}'. " 728 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
729 ) 730 731 result = cls.get(dialect_name.strip()) 732 if not result: 733 from difflib import get_close_matches 734 735 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 736 if similar: 737 similar = f" Did you mean {similar}?" 738 739 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 740 741 return result(**kwargs) 742 743 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 744 745 @classmethod 746 def format_time( 747 cls, expression: t.Optional[str | exp.Expression] 748 ) -> t.Optional[exp.Expression]: 749 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 750 if isinstance(expression, str): 751 return exp.Literal.string( 752 # the time formats are quoted 753 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 754 ) 755 756 if expression and expression.is_string: 757 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 758 759 return expression 760 761 def __init__(self, **kwargs) -> None: 762 normalization_strategy = kwargs.pop("normalization_strategy", None) 763 764 if normalization_strategy is None: 765 self.normalization_strategy = self.NORMALIZATION_STRATEGY 766 else: 767 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 768 769 self.settings = kwargs 770 771 def __eq__(self, other: t.Any) -> bool: 772 # Does not currently take dialect state into account 773 return type(self) == other 774 775 def __hash__(self) -> int: 776 # Does not currently take dialect state into account 777 return hash(type(self)) 778 779 def normalize_identifier(self, expression: E) -> E: 780 """ 781 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 782 783 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 784 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 785 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 786 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 787 788 There are also dialects like Spark, which are case-insensitive even when quotes are 789 present, and dialects like MySQL, whose resolution rules match those employed by the 790 underlying operating system, for example they may always be case-sensitive in Linux. 791 792 Finally, the normalization behavior of some engines can even be controlled through flags, 793 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 794 795 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 796 that it can analyze queries in the optimizer and successfully capture their semantics. 
797 """ 798 if ( 799 isinstance(expression, exp.Identifier) 800 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 801 and ( 802 not expression.quoted 803 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 804 ) 805 ): 806 expression.set( 807 "this", 808 ( 809 expression.this.upper() 810 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 811 else expression.this.lower() 812 ), 813 ) 814 815 return expression 816 817 def case_sensitive(self, text: str) -> bool: 818 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 819 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 820 return False 821 822 unsafe = ( 823 str.islower 824 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 825 else str.isupper 826 ) 827 return any(unsafe(char) for char in text) 828 829 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 830 """Checks if text can be identified given an identify option. 831 832 Args: 833 text: The text to check. 834 identify: 835 `"always"` or `True`: Always returns `True`. 836 `"safe"`: Only returns `True` if the identifier is case-insensitive. 837 838 Returns: 839 Whether the given text can be identified. 840 """ 841 if identify is True or identify == "always": 842 return True 843 844 if identify == "safe": 845 return not self.case_sensitive(text) 846 847 return False 848 849 def quote_identifier(self, expression: E, identify: bool = True) -> E: 850 """ 851 Adds quotes to a given identifier. 852 853 Args: 854 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 855 identify: If set to `False`, the quotes will only be added if the identifier is deemed 856 "unsafe", with respect to its characters and this dialect's normalization strategy. 857 """ 858 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 859 name = expression.this 860 expression.set( 861 "quoted", 862 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 863 ) 864 865 return expression 866 867 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 868 if isinstance(path, exp.Literal): 869 path_text = path.name 870 if path.is_number: 871 path_text = f"[{path_text}]" 872 try: 873 return parse_json_path(path_text, self) 874 except ParseError as e: 875 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 876 877 return path 878 879 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 880 return self.parser(**opts).parse(self.tokenize(sql), sql) 881 882 def parse_into( 883 self, expression_type: exp.IntoType, sql: str, **opts 884 ) -> t.List[t.Optional[exp.Expression]]: 885 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 886 887 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 888 return self.generator(**opts).generate(expression, copy=copy) 889 890 def transpile(self, sql: str, **opts) -> t.List[str]: 891 return [ 892 self.generate(expression, copy=False, **opts) if expression else "" 893 for expression in self.parse(sql) 894 ] 895 896 def tokenize(self, sql: str) -> t.List[Token]: 897 return self.tokenizer.tokenize(sql) 898 899 @property 900 def tokenizer(self) -> Tokenizer: 901 return self.tokenizer_class(dialect=self) 902 903 @property 904 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 905 return self.jsonpath_tokenizer_class(dialect=self) 906 907 def parser(self, **opts) -> Parser: 908 return self.parser_class(dialect=self, **opts) 909 910 def generator(self, **opts) -> Generator: 911 return self.generator_class(dialect=self, **opts)
761 def __init__(self, **kwargs) -> None: 762 normalization_strategy = kwargs.pop("normalization_strategy", None) 763 764 if normalization_strategy is None: 765 self.normalization_strategy = self.NORMALIZATION_STRATEGY 766 else: 767 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 768 769 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the `LOG` function.
Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
Default `NULL` ordering method to use if not explicitly set.
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
Whether the behavior of `a / b` depends on the types of `a` and `b`.
`False` means `a / b` is always float division.
`True` means `a / b` is integer division if both `a` and `b` are integers.
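To make the distinction concrete, these are plain class-level flags that can be inspected directly; a small sketch (the dialect names are chosen for illustration, and each dialect's actual settings depend on the sqlglot version):

from sqlglot.dialects.dialect import Dialect

for name in ("duckdb", "hive", "presto"):
    dialect = Dialect.get_or_raise(name)
    # TYPED_DIVISION and SAFE_DIVISION are ordinary class attributes, so no
    # parsing or generation is needed to check a dialect's division semantics.
    print(name, dialect.TYPED_DIVISION, dialect.SAFE_DIVISION)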
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` formats.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an escaped sequence (`\n`) to its unescaped version (a literal newline character).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,
WITH y(c) AS (
SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS (
SELECT
1 AS id,
2 AS my_id
)
SELECT
id AS my_id
FROM
data
WHERE
my_id = 1
GROUP BY
my_id
HAVING
my_id = 1
In most dialects "my_id" would refer to "data.my_id" (which is resolved in _qualify_columns()) across the query, except:
- BigQuery, which forwards the alias to the GROUP BY and HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- Clickhouse, which forwards the alias across the whole query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns), as in DuckDB and Spark3/Databricks.
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] while the latter is SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g. in DuckDB. In dialects which don't support fixed-size arrays, such as Snowflake, this should be interpreted as a subscript/index operator.
677 @classmethod 678 def get_or_raise(cls, dialect: DialectType) -> Dialect: 679 """ 680 Look up a dialect in the global dialect registry and return it if it exists. 681 682 Args: 683 dialect: The target dialect. If this is a string, it can be optionally followed by 684 additional key-value pairs that are separated by commas and are used to specify 685 dialect settings, such as whether the dialect's identifiers are case-sensitive. 686 687 Example: 688 >>> dialect = dialect_class = get_or_raise("duckdb") 689 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 690 691 Returns: 692 The corresponding Dialect instance. 693 """ 694 695 if not dialect: 696 return cls() 697 if isinstance(dialect, _Dialect): 698 return dialect() 699 if isinstance(dialect, Dialect): 700 return dialect 701 if isinstance(dialect, str): 702 try: 703 dialect_name, *kv_strings = dialect.split(",") 704 kv_pairs = (kv.split("=") for kv in kv_strings) 705 kwargs = {} 706 for pair in kv_pairs: 707 key = pair[0].strip() 708 value: t.Union[bool | str | None] = None 709 710 if len(pair) == 1: 711 # Default initialize standalone settings to True 712 value = True 713 elif len(pair) == 2: 714 value = pair[1].strip() 715 716 # Coerce the value to boolean if it matches to the truthy/falsy values below 717 value_lower = value.lower() 718 if value_lower in ("true", "1"): 719 value = True 720 elif value_lower in ("false", "0"): 721 value = False 722 723 kwargs[key] = value 724 725 except ValueError: 726 raise ValueError( 727 f"Invalid dialect format: '{dialect}'. " 728 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 729 ) 730 731 result = cls.get(dialect_name.strip()) 732 if not result: 733 from difflib import get_close_matches 734 735 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 736 if similar: 737 similar = f" Did you mean {similar}?" 738 739 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 740 741 return result(**kwargs) 742 743 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
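Beyond the doctest above, settings the constructor does not recognize are kept on the instance's settings dict. A sketch of the parsing rules (`my_flag` and `strict` are hypothetical keys, used only to illustrate the coercion behavior):

from sqlglot.dialects.dialect import Dialect

# A standalone key defaults to True, and "true"/"false"/"1"/"0" values are
# coerced to booleans; keys the constructor doesn't consume land in `settings`.
d = Dialect.get_or_raise("mysql, my_flag, strict = false")
print(d.settings)  # {'my_flag': True, 'strict': False}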
745 @classmethod 746 def format_time( 747 cls, expression: t.Optional[str | exp.Expression] 748 ) -> t.Optional[exp.Expression]: 749 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 750 if isinstance(expression, str): 751 return exp.Literal.string( 752 # the time formats are quoted 753 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 754 ) 755 756 if expression and expression.is_string: 757 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 758 759 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
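A sketch of the conversion, assuming a Hive-style TIME_MAPPING in which yyyy, MM and dd map to %Y, %m and %d (true for the bundled Hive dialect at the time of writing, but version-dependent):

import sqlglot  # importing sqlglot registers the bundled dialects
from sqlglot.dialects.dialect import Dialect

# Plain strings are expected to carry their SQL quotes, which are stripped
# before the trie-based conversion runs.
lit = Dialect["hive"].format_time("'yyyy-MM-dd'")
print(lit.this)  # '%Y-%m-%d' under the assumed mapping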
779 def normalize_identifier(self, expression: E) -> E: 780 """ 781 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 782 783 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 784 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 785 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 786 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 787 788 There are also dialects like Spark, which are case-insensitive even when quotes are 789 present, and dialects like MySQL, whose resolution rules match those employed by the 790 underlying operating system, for example they may always be case-sensitive in Linux. 791 792 Finally, the normalization behavior of some engines can even be controlled through flags, 793 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 794 795 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 796 that it can analyze queries in the optimizer and successfully capture their semantics. 797 """ 798 if ( 799 isinstance(expression, exp.Identifier) 800 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 801 and ( 802 not expression.quoted 803 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 804 ) 805 ): 806 expression.set( 807 "this", 808 ( 809 expression.this.upper() 810 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 811 else expression.this.lower() 812 ), 813 ) 814 815 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
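A short sketch under the default strategy (lowercasing, per NORMALIZATION_STRATEGY above), showing that quoting shields an identifier from normalization:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect()  # default strategy: LOWERCASE

# Unquoted identifiers are folded to the dialect's preferred case...
print(d.normalize_identifier(exp.to_identifier("FoO")).name)  # foo

# ...while quoted ones are treated as case-sensitive and left intact.
print(d.normalize_identifier(exp.to_identifier("FoO", quoted=True)).name)  # FoO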
817 def case_sensitive(self, text: str) -> bool: 818 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 819 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 820 return False 821 822 unsafe = ( 823 str.islower 824 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 825 else str.isupper 826 ) 827 return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
829 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 830 """Checks if text can be identified given an identify option. 831 832 Args: 833 text: The text to check. 834 identify: 835 `"always"` or `True`: Always returns `True`. 836 `"safe"`: Only returns `True` if the identifier is case-insensitive. 837 838 Returns: 839 Whether the given text can be identified. 840 """ 841 if identify is True or identify == "always": 842 return True 843 844 if identify == "safe": 845 return not self.case_sensitive(text) 846 847 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: Always returns `True`. `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
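For instance, under the default lowercase strategy (a sketch; the outputs follow from case_sensitive above):

from sqlglot.dialects.dialect import Dialect

d = Dialect()

print(d.can_identify("foo", "safe"))    # True: nothing case-sensitive here
print(d.can_identify("Foo", "safe"))    # False: the uppercase F is "unsafe"
print(d.can_identify("Foo", "always"))  # True: always quote, regardless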
849 def quote_identifier(self, expression: E, identify: bool = True) -> E: 850 """ 851 Adds quotes to a given identifier. 852 853 Args: 854 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 855 identify: If set to `False`, the quotes will only be added if the identifier is deemed 856 "unsafe", with respect to its characters and this dialect's normalization strategy. 857 """ 858 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 859 name = expression.this 860 expression.set( 861 "quoted", 862 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 863 ) 864 865 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
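A sketch of the two modes (default dialect; SAFE_IDENTIFIER_RE is the module-level regex referenced in the body above):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect()

# identify=False only quotes identifiers deemed unsafe: mixed case is unsafe
# under the default lowercase strategy...
print(d.quote_identifier(exp.to_identifier("FoO"), identify=False).sql())  # "FoO"

# ...whereas a lowercase name matching SAFE_IDENTIFIER_RE stays bare.
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo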
867 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 868 if isinstance(path, exp.Literal): 869 path_text = path.name 870 if path.is_number: 871 path_text = f"[{path_text}]" 872 try: 873 return parse_json_path(path_text, self) 874 except ParseError as e: 875 logger.warning(f"Invalid JSON path syntax. {str(e)}") 876 877 return path
927def if_sql( 928 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 929) -> t.Callable[[Generator, exp.If], str]: 930 def _if_sql(self: Generator, expression: exp.If) -> str: 931 return self.func( 932 name, 933 expression.this, 934 expression.args.get("true"), 935 expression.args.get("false") or false_value, 936 ) 937 938 return _if_sql
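Factories like this one are meant to be installed in a dialect's Generator.TRANSFORMS map. A hedged sketch of that wiring (MyDialect and the IIF rendering are illustrative, not a description of any bundled dialect):

import sqlglot
from sqlglot import exp, generator
from sqlglot.dialects.dialect import Dialect, if_sql

class MyDialect(Dialect):
    class Generator(generator.Generator):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            # Render exp.If as IIF(cond, true, false), defaulting a missing
            # false branch to NULL.
            exp.If: if_sql(name="IIF", false_value="NULL"),
        }

print(sqlglot.transpile("SELECT IF(cond, 1)", write="mydialect")[0])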
941def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 942 this = expression.this 943 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 944 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 945 946 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1012def str_position_sql( 1013 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1014) -> str: 1015 this = self.sql(expression, "this") 1016 substr = self.sql(expression, "substr") 1017 position = self.sql(expression, "position") 1018 instance = expression.args.get("instance") if generate_instance else None 1019 position_offset = "" 1020 1021 if position: 1022 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1023 this = self.func("SUBSTR", this, position) 1024 position_offset = f" + {position} - 1" 1025 1026 return self.func("STRPOS", this, substr, instance) + position_offset
1035def var_map_sql( 1036 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1037) -> str: 1038 keys = expression.args["keys"] 1039 values = expression.args["values"] 1040 1041 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1042 self.unsupported("Cannot convert array columns into map.") 1043 return self.func(map_func_name, keys, values) 1044 1045 args = [] 1046 for key, value in zip(keys.expressions, values.expressions): 1047 args.append(self.sql(key)) 1048 args.append(self.sql(value)) 1049 1050 return self.func(map_func_name, *args)
1053def build_formatted_time( 1054 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1055) -> t.Callable[[t.List], E]: 1056 """Helper used for time expressions. 1057 1058 Args: 1059 exp_class: the expression class to instantiate. 1060 dialect: target sql dialect. 1061 default: the default format, True being time. 1062 1063 Returns: 1064 A callable that can be used to return the appropriately formatted time expression. 1065 """ 1066 1067 def _builder(args: t.List): 1068 return exp_class( 1069 this=seq_get(args, 0), 1070 format=Dialect[dialect].format_time( 1071 seq_get(args, 1) 1072 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1073 ), 1074 ) 1075 1076 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
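In practice this builder is registered in a dialect's Parser.FUNCTIONS map, so the format argument is converted through the dialect's TIME_MAPPING at parse time. A sketch under a made-up dialect (MyDialect, its mapping, and the DATE_FORMAT entry mirror how bundled dialects use the helper, but are illustrative):

import sqlglot
from sqlglot import exp, parser
from sqlglot.dialects.dialect import Dialect, build_formatted_time

class MyDialect(Dialect):
    # Dialect-native format tokens and their Python strftime equivalents.
    TIME_MAPPING = {"yyyy": "%Y", "MM": "%m", "dd": "%d"}

    class Parser(parser.Parser):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            # DATE_FORMAT(x, 'yyyy-MM-dd') parses into exp.TimeToStr whose
            # format literal becomes '%Y-%m-%d'.
            "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mydialect"),
        }

ast = sqlglot.parse_one("SELECT DATE_FORMAT(x, 'yyyy-MM-dd')", read="mydialect")
print(ast.find(exp.TimeToStr).args["format"])  # '%Y-%m-%d'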
1079def time_format( 1080 dialect: DialectType = None, 1081) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1082 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1083 """ 1084 Returns the time format for a given expression, unless it's equivalent 1085 to the default time format of the dialect of interest. 1086 """ 1087 time_format = self.format_time(expression) 1088 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1089 1090 return _time_format
1093def build_date_delta( 1094 exp_class: t.Type[E], 1095 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1096 default_unit: t.Optional[str] = "DAY", 1097) -> t.Callable[[t.List], E]: 1098 def _builder(args: t.List) -> E: 1099 unit_based = len(args) == 3 1100 this = args[2] if unit_based else seq_get(args, 0) 1101 unit = None 1102 if unit_based or default_unit: 1103 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1104 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1105 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1106 1107 return _builder
1110def build_date_delta_with_interval( 1111 expression_class: t.Type[E], 1112) -> t.Callable[[t.List], t.Optional[E]]: 1113 def _builder(args: t.List) -> t.Optional[E]: 1114 if len(args) < 2: 1115 return None 1116 1117 interval = args[1] 1118 1119 if not isinstance(interval, exp.Interval): 1120 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1121 1122 expression = interval.this 1123 if expression and expression.is_string: 1124 expression = exp.Literal.number(expression.this) 1125 1126 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 1127 1128 return _builder
1131def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1132 unit = seq_get(args, 0) 1133 this = seq_get(args, 1) 1134 1135 if isinstance(this, exp.Cast) and this.is_type("date"): 1136 return exp.DateTrunc(unit=unit, this=this) 1137 return exp.TimestampTrunc(this=this, unit=unit)
1140def date_add_interval_sql( 1141 data_type: str, kind: str 1142) -> t.Callable[[Generator, exp.Expression], str]: 1143 def func(self: Generator, expression: exp.Expression) -> str: 1144 this = self.sql(expression, "this") 1145 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1146 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1147 1148 return func
1151def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1152 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1153 args = [unit_to_str(expression), expression.this] 1154 if zone: 1155 args.append(expression.args.get("zone")) 1156 return self.func("DATE_TRUNC", *args) 1157 1158 return _timestamptrunc_sql
1161def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1162 zone = expression.args.get("zone") 1163 if not zone: 1164 from sqlglot.optimizer.annotate_types import annotate_types 1165 1166 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1167 return self.sql(exp.cast(expression.this, target_type)) 1168 if zone.name.lower() in TIMEZONES: 1169 return self.sql( 1170 exp.AtTimeZone( 1171 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1172 zone=zone, 1173 ) 1174 ) 1175 return self.func("TIMESTAMP", expression.this, zone)
1178def no_time_sql(self: Generator, expression: exp.Time) -> str: 1179 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1180 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1181 expr = exp.cast( 1182 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1183 ) 1184 return self.sql(expr)
1187def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1188 this = expression.this 1189 expr = expression.expression 1190 1191 if expr.name.lower() in TIMEZONES: 1192 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1193 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1194 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1195 return self.sql(this) 1196 1197 this = exp.cast(this, exp.DataType.Type.DATE) 1198 expr = exp.cast(expr, exp.DataType.Type.TIME) 1199 1200 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1241def encode_decode_sql( 1242 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1243) -> str: 1244 charset = expression.args.get("charset") 1245 if charset and charset.name.lower() != "utf-8": 1246 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1247 1248 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1261def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1262 cond = expression.this 1263 1264 if isinstance(expression.this, exp.Distinct): 1265 cond = expression.this.expressions[0] 1266 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1267 1268 return self.func("sum", exp.func("if", cond, 1, 0))
1271def trim_sql(self: Generator, expression: exp.Trim) -> str: 1272 target = self.sql(expression, "this") 1273 trim_type = self.sql(expression, "position") 1274 remove_chars = self.sql(expression, "expression") 1275 collation = self.sql(expression, "collation") 1276 1277 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1278 if not remove_chars and not collation: 1279 return self.trim_sql(expression) 1280 1281 trim_type = f"{trim_type} " if trim_type else "" 1282 remove_chars = f"{remove_chars} " if remove_chars else "" 1283 from_part = "FROM " if trim_type or remove_chars else "" 1284 collation = f" COLLATE {collation}" if collation else "" 1285 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1306def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1307 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1308 if bad_args: 1309 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1310 1311 return self.func( 1312 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1313 )
1316def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1317 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1318 if bad_args: 1319 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1320 1321 return self.func( 1322 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1323 )
1326def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1327 names = [] 1328 for agg in aggregations: 1329 if isinstance(agg, exp.Alias): 1330 names.append(agg.alias) 1331 else: 1332 """ 1333 This case corresponds to aggregations without aliases being used as suffixes 1334 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1335 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1336 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1337 """ 1338 agg_all_unquoted = agg.transform( 1339 lambda node: ( 1340 exp.Identifier(this=node.name, quoted=False) 1341 if isinstance(node, exp.Identifier) 1342 else node 1343 ) 1344 ) 1345 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1346 1347 return names
1387def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1388 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1389 if expression.args.get("count"): 1390 self.unsupported(f"Only two arguments are supported in function {name}.") 1391 1392 return self.func(name, expression.this, expression.expression) 1393 1394 return _arg_max_or_min_sql
1397def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1398 this = expression.this.copy() 1399 1400 return_type = expression.return_type 1401 if return_type.is_type(exp.DataType.Type.DATE): 1402 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1403 # can truncate timestamp strings, because some dialects can't cast them to DATE 1404 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1405 1406 expression.this.replace(exp.cast(this, return_type)) 1407 return expression
1410def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1411 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1412 if cast and isinstance(expression, exp.TsOrDsAdd): 1413 expression = ts_or_ds_add_cast(expression) 1414 1415 return self.func( 1416 name, 1417 unit_to_var(expression), 1418 expression.expression, 1419 expression.this, 1420 ) 1421 1422 return _delta_sql
1425def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1426 unit = expression.args.get("unit") 1427 1428 if isinstance(unit, exp.Placeholder): 1429 return unit 1430 if unit: 1431 return exp.Literal.string(unit.name) 1432 return exp.Literal.string(default) if default else None
1462def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1463 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1464 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1465 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1466 1467 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1470def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1471 """Remove table refs from columns in when statements.""" 1472 alias = expression.this.args.get("alias") 1473 1474 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1475 return self.dialect.normalize_identifier(identifier).name if identifier else None 1476 1477 targets = {normalize(expression.this.this)} 1478 1479 if alias: 1480 targets.add(normalize(alias.this)) 1481 1482 for when in expression.expressions: 1483 when.transform( 1484 lambda node: ( 1485 exp.column(node.this) 1486 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1487 else node 1488 ), 1489 copy=False, 1490 ) 1491 1492 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1495def build_json_extract_path( 1496 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1497) -> t.Callable[[t.List], F]: 1498 def _builder(args: t.List) -> F: 1499 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1500 for arg in args[1:]: 1501 if not isinstance(arg, exp.Literal): 1502 # We use the fallback parser because we can't really transpile non-literals safely 1503 return expr_type.from_arg_list(args) 1504 1505 text = arg.name 1506 if is_int(text): 1507 index = int(text) 1508 segments.append( 1509 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1510 ) 1511 else: 1512 segments.append(exp.JSONPathKey(this=text)) 1513 1514 # This is done to avoid failing in the expression validator due to the arg count 1515 del args[2:] 1516 return expr_type( 1517 this=seq_get(args, 0), 1518 expression=exp.JSONPath(expressions=segments), 1519 only_json_types=arrow_req_json_type, 1520 ) 1521 1522 return _builder
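As with the other builders, this one is typically dropped into Parser.FUNCTIONS. A brief sketch (the function name echoes Snowflake/Redshift-style GET_PATH helpers, but the wiring here is illustrative):

from sqlglot import exp, parser
from sqlglot.dialects.dialect import Dialect, build_json_extract_path

class MyDialect(Dialect):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            # Literal path arguments become a structured exp.JSONPath;
            # non-literal arguments fall back to the default builder.
            "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
        }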
1525def json_extract_segments( 1526 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1527) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1528 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1529 path = expression.expression 1530 if not isinstance(path, exp.JSONPath): 1531 return rename_func(name)(self, expression) 1532 1533 segments = [] 1534 for segment in path.expressions: 1535 path = self.sql(segment) 1536 if path: 1537 if isinstance(segment, exp.JSONPathPart) and ( 1538 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1539 ): 1540 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1541 1542 segments.append(path) 1543 1544 if op: 1545 return f" {op} ".join([self.sql(expression.this), *segments]) 1546 return self.func(name, expression.this, *segments) 1547 1548 return _json_extract_segments
1558def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1559 cond = expression.expression 1560 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1561 alias = cond.expressions[0] 1562 cond = cond.this 1563 elif isinstance(cond, exp.Predicate): 1564 alias = "_u" 1565 else: 1566 self.unsupported("Unsupported filter condition") 1567 return "" 1568 1569 unnest = exp.Unnest(expressions=[expression.this]) 1570 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1571 return self.sql(exp.Array(expressions=[filtered]))
1583def build_default_decimal_type( 1584 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1585) -> t.Callable[[exp.DataType], exp.DataType]: 1586 def _builder(dtype: exp.DataType) -> exp.DataType: 1587 if dtype.expressions or precision is None: 1588 return dtype 1589 1590 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1591 return exp.DataType.build(f"DECIMAL({params})") 1592 1593 return _builder
1596def build_timestamp_from_parts(args: t.List) -> exp.Func: 1597 if len(args) == 2: 1598 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1599 # so we parse this into Anonymous for now instead of introducing complexity 1600 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1601 1602 return exp.TimestampFromParts.from_arg_list(args)
1609def sequence_sql(self: Generator, expression: exp.GenerateSeries): 1610 start = expression.args["start"] 1611 end = expression.args["end"] 1612 step = expression.args.get("step") 1613 1614 if isinstance(start, exp.Cast): 1615 target_type = start.to 1616 elif isinstance(end, exp.Cast): 1617 target_type = end.to 1618 else: 1619 target_type = None 1620 1621 if target_type and target_type.is_type("timestamp"): 1622 if target_type is start.to: 1623 end = exp.cast(end, target_type) 1624 else: 1625 start = exp.cast(start, target_type) 1626 1627 return self.func("SEQUENCE", start, end, step)