sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28
  29class Dialects(str, Enum):
  30    """Dialects supported by SQLGLot."""
  31
  32    DIALECT = ""
  33
  34    ATHENA = "athena"
  35    BIGQUERY = "bigquery"
  36    CLICKHOUSE = "clickhouse"
  37    DATABRICKS = "databricks"
  38    DORIS = "doris"
  39    DRILL = "drill"
  40    DUCKDB = "duckdb"
  41    HIVE = "hive"
  42    MYSQL = "mysql"
  43    ORACLE = "oracle"
  44    POSTGRES = "postgres"
  45    PRESTO = "presto"
  46    PRQL = "prql"
  47    REDSHIFT = "redshift"
  48    SNOWFLAKE = "snowflake"
  49    SPARK = "spark"
  50    SPARK2 = "spark2"
  51    SQLITE = "sqlite"
  52    STARROCKS = "starrocks"
  53    TABLEAU = "tableau"
  54    TERADATA = "teradata"
  55    TRINO = "trino"
  56    TSQL = "tsql"
  57
  58
  59class NormalizationStrategy(str, AutoName):
  60    """Specifies the strategy according to which identifiers should be normalized."""
  61
  62    LOWERCASE = auto()
  63    """Unquoted identifiers are lowercased."""
  64
  65    UPPERCASE = auto()
  66    """Unquoted identifiers are uppercased."""
  67
  68    CASE_SENSITIVE = auto()
  69    """Always case-sensitive, regardless of quotes."""
  70
  71    CASE_INSENSITIVE = auto()
  72    """Always case-insensitive, regardless of quotes."""
  73
  74
  75class _Dialect(type):
  76    classes: t.Dict[str, t.Type[Dialect]] = {}
  77
  78    def __eq__(cls, other: t.Any) -> bool:
  79        if cls is other:
  80            return True
  81        if isinstance(other, str):
  82            return cls is cls.get(other)
  83        if isinstance(other, Dialect):
  84            return cls is type(other)
  85
  86        return False
  87
  88    def __hash__(cls) -> int:
  89        return hash(cls.__name__.lower())
  90
  91    @classmethod
  92    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  93        return cls.classes[key]
  94
  95    @classmethod
  96    def get(
  97        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  98    ) -> t.Optional[t.Type[Dialect]]:
  99        return cls.classes.get(key, default)
 100
 101    def __new__(cls, clsname, bases, attrs):
 102        klass = super().__new__(cls, clsname, bases, attrs)
 103        enum = Dialects.__members__.get(clsname.upper())
 104        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 105
 106        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 107        klass.FORMAT_TRIE = (
 108            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 109        )
 110        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 111        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 112
 113        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
 114
 115        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 116        klass.parser_class = getattr(klass, "Parser", Parser)
 117        klass.generator_class = getattr(klass, "Generator", Generator)
 118
 119        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 120        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 121            klass.tokenizer_class._IDENTIFIERS.items()
 122        )[0]
 123
 124        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 125            return next(
 126                (
 127                    (s, e)
 128                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 129                    if t == token_type
 130                ),
 131                (None, None),
 132            )
 133
 134        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 135        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 136        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 137        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 138
 139        if enum not in ("", "bigquery"):
 140            klass.generator_class.SELECT_KINDS = ()
 141
 142        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 143            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 144            for modifier in ("cluster", "distribute", "sort"):
 145                modifier_transforms.pop(modifier, None)
 146
 147            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 148
 149        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 150            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 151                TokenType.ANTI,
 152                TokenType.SEMI,
 153            }
 154
 155        return klass
 156
 157
 158class Dialect(metaclass=_Dialect):
 159    INDEX_OFFSET = 0
 160    """The base index offset for arrays."""
 161
 162    WEEK_OFFSET = 0
 163    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 164
 165    UNNEST_COLUMN_ONLY = False
 166    """Whether `UNNEST` table aliases are treated as column aliases."""
 167
 168    ALIAS_POST_TABLESAMPLE = False
 169    """Whether the table alias comes after tablesample."""
 170
 171    TABLESAMPLE_SIZE_IS_PERCENT = False
 172    """Whether a size in the table sample clause represents percentage."""
 173
 174    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 175    """Specifies the strategy according to which identifiers should be normalized."""
 176
 177    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 178    """Whether an unquoted identifier can start with a digit."""
 179
 180    DPIPE_IS_STRING_CONCAT = True
 181    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 182
 183    STRICT_STRING_CONCAT = False
 184    """Whether `CONCAT`'s arguments must be strings."""
 185
 186    SUPPORTS_USER_DEFINED_TYPES = True
 187    """Whether user-defined data types are supported."""
 188
 189    SUPPORTS_SEMI_ANTI_JOIN = True
 190    """Whether `SEMI` or `ANTI` joins are supported."""
 191
 192    NORMALIZE_FUNCTIONS: bool | str = "upper"
 193    """
 194    Determines how function names are going to be normalized.
 195    Possible values:
 196        "upper" or True: Convert names to uppercase.
 197        "lower": Convert names to lowercase.
 198        False: Disables function name normalization.
 199    """
 200
 201    LOG_BASE_FIRST = True
 202    """Whether the base comes first in the `LOG` function."""
 203
 204    NULL_ORDERING = "nulls_are_small"
 205    """
 206    Default `NULL` ordering method to use if not explicitly set.
 207    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 208    """
 209
 210    TYPED_DIVISION = False
 211    """
 212    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 213    False means `a / b` is always float division.
 214    True means `a / b` is integer division if both `a` and `b` are integers.
 215    """
 216
 217    SAFE_DIVISION = False
 218    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 219
 220    CONCAT_COALESCE = False
 221    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 222
 223    DATE_FORMAT = "'%Y-%m-%d'"
 224    DATEINT_FORMAT = "'%Y%m%d'"
 225    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 226
 227    TIME_MAPPING: t.Dict[str, str] = {}
 228    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 229
 230    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 231    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 232    FORMAT_MAPPING: t.Dict[str, str] = {}
 233    """
 234    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 235    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 236    """
 237
 238    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 239    """Mapping of an unescaped escape sequence to the corresponding character."""
 240
 241    PSEUDOCOLUMNS: t.Set[str] = set()
 242    """
 243    Columns that are auto-generated by the engine corresponding to this dialect.
 244    For example, such columns may be excluded from `SELECT *` queries.
 245    """
 246
 247    PREFER_CTE_ALIAS_COLUMN = False
 248    """
 249    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 250    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 251    any projection aliases in the subquery.
 252
 253    For example,
 254        WITH y(c) AS (
 255            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 256        ) SELECT c FROM y;
 257
 258        will be rewritten as
 259
 260        WITH y(c) AS (
 261            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 262        ) SELECT c FROM y;
 263    """
 264
 265    # --- Autofilled ---
 266
 267    tokenizer_class = Tokenizer
 268    parser_class = Parser
 269    generator_class = Generator
 270
 271    # A trie of the time_mapping keys
 272    TIME_TRIE: t.Dict = {}
 273    FORMAT_TRIE: t.Dict = {}
 274
 275    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 276    INVERSE_TIME_TRIE: t.Dict = {}
 277
 278    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 279
 280    # Delimiters for string literals and identifiers
 281    QUOTE_START = "'"
 282    QUOTE_END = "'"
 283    IDENTIFIER_START = '"'
 284    IDENTIFIER_END = '"'
 285
 286    # Delimiters for bit, hex, byte and unicode literals
 287    BIT_START: t.Optional[str] = None
 288    BIT_END: t.Optional[str] = None
 289    HEX_START: t.Optional[str] = None
 290    HEX_END: t.Optional[str] = None
 291    BYTE_START: t.Optional[str] = None
 292    BYTE_END: t.Optional[str] = None
 293    UNICODE_START: t.Optional[str] = None
 294    UNICODE_END: t.Optional[str] = None
 295
 296    @classmethod
 297    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 298        """
 299        Look up a dialect in the global dialect registry and return it if it exists.
 300
 301        Args:
 302            dialect: The target dialect. If this is a string, it can be optionally followed by
 303                additional key-value pairs that are separated by commas and are used to specify
 304                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 305
 306        Example:
  307            >>> dialect = Dialect.get_or_raise("duckdb")
  308            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
 309
 310        Returns:
 311            The corresponding Dialect instance.
 312        """
 313
 314        if not dialect:
 315            return cls()
 316        if isinstance(dialect, _Dialect):
 317            return dialect()
 318        if isinstance(dialect, Dialect):
 319            return dialect
 320        if isinstance(dialect, str):
 321            try:
 322                dialect_name, *kv_pairs = dialect.split(",")
 323                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 324            except ValueError:
 325                raise ValueError(
 326                    f"Invalid dialect format: '{dialect}'. "
 327                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
 328                )
 329
 330            result = cls.get(dialect_name.strip())
 331            if not result:
 332                from difflib import get_close_matches
 333
 334                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 335                if similar:
 336                    similar = f" Did you mean {similar}?"
 337
 338                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 339
 340            return result(**kwargs)
 341
 342        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 343
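    # Editor's note (illustrative sketch, not part of the module): trailing
    # key-value pairs become constructor kwargs, and unknown names raise with
    # a suggestion.
    #
    #   >>> Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive").normalization_strategy
    #   <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>
    #   >>> Dialect.get_or_raise("duckdbb")  # ValueError: Unknown dialect 'duckdbb'. Did you mean duckdb?
    #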
 344    @classmethod
 345    def format_time(
 346        cls, expression: t.Optional[str | exp.Expression]
 347    ) -> t.Optional[exp.Expression]:
 348        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 349        if isinstance(expression, str):
 350            return exp.Literal.string(
 351                # the time formats are quoted
 352                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 353            )
 354
 355        if expression and expression.is_string:
 356            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 357
 358        return expression
 359
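    # Editor's note — a minimal sketch, assuming Hive's TIME_MAPPING (where
    # "yyyy", "MM" and "dd" map to "%Y", "%m" and "%d"):
    #
    #   >>> Dialect["hive"].format_time(exp.Literal.string("yyyy-MM-dd")).this
    #   '%Y-%m-%d'
    #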
 360    def __init__(self, **kwargs) -> None:
 361        normalization_strategy = kwargs.get("normalization_strategy")
 362
 363        if normalization_strategy is None:
 364            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 365        else:
 366            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 367
 368    def __eq__(self, other: t.Any) -> bool:
 369        # Does not currently take dialect state into account
 370        return type(self) == other
 371
 372    def __hash__(self) -> int:
 373        # Does not currently take dialect state into account
 374        return hash(type(self))
 375
 376    def normalize_identifier(self, expression: E) -> E:
 377        """
 378        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 379
 380        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 381        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 382        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 383        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 384
 385        There are also dialects like Spark, which are case-insensitive even when quotes are
 386        present, and dialects like MySQL, whose resolution rules match those employed by the
 387        underlying operating system, for example they may always be case-sensitive in Linux.
 388
 389        Finally, the normalization behavior of some engines can even be controlled through flags,
 390        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 391
 392        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 393        that it can analyze queries in the optimizer and successfully capture their semantics.
 394        """
 395        if (
 396            isinstance(expression, exp.Identifier)
 397            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 398            and (
 399                not expression.quoted
 400                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 401            )
 402        ):
 403            expression.set(
 404                "this",
 405                (
 406                    expression.this.upper()
 407                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 408                    else expression.this.lower()
 409                ),
 410            )
 411
 412        return expression
 413
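    # Editor's note — illustrative doctest for the strategies described above:
    #
    #   >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
    #   'foo'
    #   >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
    #   'FOO'
    #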
 414    def case_sensitive(self, text: str) -> bool:
 415        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 416        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 417            return False
 418
 419        unsafe = (
 420            str.islower
 421            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 422            else str.isupper
 423        )
 424        return any(unsafe(char) for char in text)
 425
 426    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 427        """Checks if text can be identified given an identify option.
 428
 429        Args:
 430            text: The text to check.
 431            identify:
 432                `"always"` or `True`: Always returns `True`.
 433                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 434
 435        Returns:
 436            Whether the given text can be identified.
 437        """
 438        if identify is True or identify == "always":
 439            return True
 440
 441        if identify == "safe":
 442            return not self.case_sensitive(text)
 443
 444        return False
 445
 446    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 447        """
 448        Adds quotes to a given identifier.
 449
 450        Args:
 451            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 452            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 453                "unsafe", with respect to its characters and this dialect's normalization strategy.
 454        """
 455        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 456            name = expression.this
 457            expression.set(
 458                "quoted",
 459                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 460            )
 461
 462        return expression
 463
 464    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 465        if isinstance(path, exp.Literal):
 466            path_text = path.name
 467            if path.is_number:
 468                path_text = f"[{path_text}]"
 469
 470            try:
 471                return parse_json_path(path_text)
 472            except ParseError as e:
 473                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 474
 475        return path
 476
 477    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 478        return self.parser(**opts).parse(self.tokenize(sql), sql)
 479
 480    def parse_into(
 481        self, expression_type: exp.IntoType, sql: str, **opts
 482    ) -> t.List[t.Optional[exp.Expression]]:
 483        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 484
 485    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 486        return self.generator(**opts).generate(expression, copy=copy)
 487
 488    def transpile(self, sql: str, **opts) -> t.List[str]:
 489        return [
 490            self.generate(expression, copy=False, **opts) if expression else ""
 491            for expression in self.parse(sql)
 492        ]
 493
 494    def tokenize(self, sql: str) -> t.List[Token]:
 495        return self.tokenizer.tokenize(sql)
 496
 497    @property
 498    def tokenizer(self) -> Tokenizer:
 499        if not hasattr(self, "_tokenizer"):
 500            self._tokenizer = self.tokenizer_class(dialect=self)
 501        return self._tokenizer
 502
 503    def parser(self, **opts) -> Parser:
 504        return self.parser_class(dialect=self, **opts)
 505
 506    def generator(self, **opts) -> Generator:
 507        return self.generator_class(dialect=self, **opts)
 508
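# Editor's note — a minimal end-to-end sketch of the instance API above:
#
#   >>> dialect = Dialect.get_or_raise("duckdb")
#   >>> expression = dialect.parse("select 1 + 1")[0]
#   >>> dialect.generate(expression)
#   'SELECT 1 + 1'
#   >>> dialect.transpile("select 1 + 1")
#   ['SELECT 1 + 1']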
 509
 510DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 511
 512
 513def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 514    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 515
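# Editor's note — typical use, following the pattern in the concrete dialects
# (a sketch): point an expression type at a differently named function in a
# dialect's Generator.TRANSFORMS, e.g.
#
#   TRANSFORMS = {
#       **generator.Generator.TRANSFORMS,
#       exp.ApproxDistinct: rename_func("APPROX_DISTINCT"),
#   }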
 516
 517def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 518    if expression.args.get("accuracy"):
 519        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 520    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 521
 522
 523def if_sql(
 524    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 525) -> t.Callable[[Generator, exp.If], str]:
 526    def _if_sql(self: Generator, expression: exp.If) -> str:
 527        return self.func(
 528            name,
 529            expression.this,
 530            expression.args.get("true"),
 531            expression.args.get("false") or false_value,
 532        )
 533
 534    return _if_sql
 535
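# Editor's note — a sketch: a dialect that spells the conditional IIF could
# register exp.If: if_sql(name="IIF") in its TRANSFORMS; `false_value` supplies
# a default for when the ELSE branch is absent.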
 536
 537def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 538    this = expression.this
 539    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 540        this.replace(exp.cast(this, "json"))
 541
 542    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 543
 544
 545def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 546    return f"[{self.expressions(expression, flat=True)}]"
 547
 548
 549def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 550    return self.like_sql(
 551        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 552    )
 553
 554
 555def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 556    zone = self.sql(expression, "this")
 557    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 558
 559
 560def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 561    if expression.args.get("recursive"):
 562        self.unsupported("Recursive CTEs are unsupported")
 563        expression.args["recursive"] = False
 564    return self.with_sql(expression)
 565
 566
 567def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 568    n = self.sql(expression, "this")
 569    d = self.sql(expression, "expression")
 570    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 571
 572
 573def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 574    self.unsupported("TABLESAMPLE unsupported")
 575    return self.sql(expression.this)
 576
 577
 578def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 579    self.unsupported("PIVOT unsupported")
 580    return ""
 581
 582
 583def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 584    return self.cast_sql(expression)
 585
 586
 587def no_comment_column_constraint_sql(
 588    self: Generator, expression: exp.CommentColumnConstraint
 589) -> str:
 590    self.unsupported("CommentColumnConstraint unsupported")
 591    return ""
 592
 593
 594def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 595    self.unsupported("MAP_FROM_ENTRIES unsupported")
 596    return ""
 597
 598
 599def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
 600    this = self.sql(expression, "this")
 601    substr = self.sql(expression, "substr")
 602    position = self.sql(expression, "position")
 603    if position:
 604        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
 605    return f"STRPOS({this}, {substr})"
 606
 607
 608def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 609    return (
 610        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 611    )
 612
 613
 614def var_map_sql(
 615    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 616) -> str:
 617    keys = expression.args["keys"]
 618    values = expression.args["values"]
 619
 620    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 621        self.unsupported("Cannot convert array columns into map.")
 622        return self.func(map_func_name, keys, values)
 623
 624    args = []
 625    for key, value in zip(keys.expressions, values.expressions):
 626        args.append(self.sql(key))
 627        args.append(self.sql(value))
 628
 629    return self.func(map_func_name, *args)
 630
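# Editor's note — illustrative rewrite performed above: MAP(ARRAY('a', 'b'),
# ARRAY(1, 2)) is flattened to MAP('a', 1, 'b', 2); non-array (e.g. column)
# arguments fall back to the two-argument form via self.unsupported.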
 631
 632def build_formatted_time(
 633    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 634) -> t.Callable[[t.List], E]:
 635    """Helper used for time expressions.
 636
 637    Args:
 638        exp_class: the expression class to instantiate.
 639        dialect: target sql dialect.
 640        default: the default format, True being time.
 641
 642    Returns:
 643        A callable that can be used to return the appropriately formatted time expression.
 644    """
 645
 646    def _builder(args: t.List):
 647        return exp_class(
 648            this=seq_get(args, 0),
 649            format=Dialect[dialect].format_time(
 650                seq_get(args, 1)
 651                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 652            ),
 653        )
 654
 655    return _builder
 656
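# Editor's note — the concrete dialects register parser entries along these
# lines (a sketch; the exact FUNCTIONS keys vary per dialect):
#
#   FUNCTIONS = {"DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mysql")}
#
# so the format argument is converted to strftime tokens through
# Dialect[dialect].format_time.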
 657
 658def time_format(
 659    dialect: DialectType = None,
 660) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 661    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 662        """
 663        Returns the time format for a given expression, unless it's equivalent
 664        to the default time format of the dialect of interest.
 665        """
 666        time_format = self.format_time(expression)
 667        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 668
 669    return _time_format
 670
 671
 672def build_date_delta(
 673    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 674) -> t.Callable[[t.List], E]:
 675    def _builder(args: t.List) -> E:
 676        unit_based = len(args) == 3
 677        this = args[2] if unit_based else seq_get(args, 0)
 678        unit = args[0] if unit_based else exp.Literal.string("DAY")
 679        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 680        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 681
 682    return _builder
 683
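# Editor's note — a sketch of the two shapes accepted above: DATEADD(unit, n, d)
# (three arguments, unit first) and DATEADD(d, n) (unit defaults to DAY), e.g.
#
#   FUNCTIONS = {"DATEADD": build_date_delta(exp.DateAdd)}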
 684
 685def build_date_delta_with_interval(
 686    expression_class: t.Type[E],
 687) -> t.Callable[[t.List], t.Optional[E]]:
 688    def _builder(args: t.List) -> t.Optional[E]:
 689        if len(args) < 2:
 690            return None
 691
 692        interval = args[1]
 693
 694        if not isinstance(interval, exp.Interval):
 695            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 696
 697        expression = interval.this
 698        if expression and expression.is_string:
 699            expression = exp.Literal.number(expression.this)
 700
 701        return expression_class(
 702            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
 703        )
 704
 705    return _builder
 706
 707
 708def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 709    unit = seq_get(args, 0)
 710    this = seq_get(args, 1)
 711
 712    if isinstance(this, exp.Cast) and this.is_type("date"):
 713        return exp.DateTrunc(unit=unit, this=this)
 714    return exp.TimestampTrunc(this=this, unit=unit)
 715
 716
 717def date_add_interval_sql(
 718    data_type: str, kind: str
 719) -> t.Callable[[Generator, exp.Expression], str]:
 720    def func(self: Generator, expression: exp.Expression) -> str:
 721        this = self.sql(expression, "this")
 722        unit = expression.args.get("unit")
 723        unit = exp.var(unit.name.upper() if unit else "DAY")
 724        interval = exp.Interval(this=expression.expression, unit=unit)
 725        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 726
 727    return func
 728
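# Editor's note — illustrative output: date_add_interval_sql("DATE", "ADD")
# renders an exp.DateAdd as DATE_ADD(x, INTERVAL 7 DAY) (BigQuery-style
# syntax), with the unit defaulting to DAY when absent.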
 729
 730def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 731    return self.func(
 732        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
 733    )
 734
 735
 736def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 737    if not expression.expression:
 738        from sqlglot.optimizer.annotate_types import annotate_types
 739
 740        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 741        return self.sql(exp.cast(expression.this, to=target_type))
 742    if expression.text("expression").lower() in TIMEZONES:
 743        return self.sql(
 744            exp.AtTimeZone(
 745                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
 746                zone=expression.expression,
 747            )
 748        )
 749    return self.func("TIMESTAMP", expression.this, expression.expression)
 750
 751
 752def locate_to_strposition(args: t.List) -> exp.Expression:
 753    return exp.StrPosition(
 754        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 755    )
 756
 757
 758def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 759    return self.func(
 760        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 761    )
 762
 763
 764def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 765    return self.sql(
 766        exp.Substring(
 767            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 768        )
 769    )
 770
 771
  772def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
 773    return self.sql(
 774        exp.Substring(
 775            this=expression.this,
 776            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 777        )
 778    )
 779
 780
 781def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 782    return self.sql(exp.cast(expression.this, "timestamp"))
 783
 784
 785def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 786    return self.sql(exp.cast(expression.this, "date"))
 787
 788
 789# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
 790def encode_decode_sql(
 791    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 792) -> str:
 793    charset = expression.args.get("charset")
 794    if charset and charset.name.lower() != "utf-8":
 795        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 796
 797    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 798
 799
 800def min_or_least(self: Generator, expression: exp.Min) -> str:
 801    name = "LEAST" if expression.expressions else "MIN"
 802    return rename_func(name)(self, expression)
 803
 804
 805def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 806    name = "GREATEST" if expression.expressions else "MAX"
 807    return rename_func(name)(self, expression)
 808
 809
 810def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 811    cond = expression.this
 812
 813    if isinstance(expression.this, exp.Distinct):
 814        cond = expression.this.expressions[0]
 815        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 816
 817    return self.func("sum", exp.func("if", cond, 1, 0))
 818
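# Editor's note — illustrative rewrite performed above:
# COUNT_IF(x > 0) becomes SUM(IF(x > 0, 1, 0)) for dialects without COUNT_IF.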
 819
 820def trim_sql(self: Generator, expression: exp.Trim) -> str:
 821    target = self.sql(expression, "this")
 822    trim_type = self.sql(expression, "position")
 823    remove_chars = self.sql(expression, "expression")
 824    collation = self.sql(expression, "collation")
 825
 826    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 827    if not remove_chars and not collation:
 828        return self.trim_sql(expression)
 829
 830    trim_type = f"{trim_type} " if trim_type else ""
 831    remove_chars = f"{remove_chars} " if remove_chars else ""
 832    from_part = "FROM " if trim_type or remove_chars else ""
 833    collation = f" COLLATE {collation}" if collation else ""
 834    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 835
 836
 837def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 838    return self.func("STRPTIME", expression.this, self.format_time(expression))
 839
 840
 841def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 842    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 843
 844
 845def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 846    delim, *rest_args = expression.expressions
 847    return self.sql(
 848        reduce(
 849            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 850            rest_args,
 851        )
 852    )
 853
 854
 855def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 856    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 857    if bad_args:
 858        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 859
 860    return self.func(
 861        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 862    )
 863
 864
 865def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 866    bad_args = list(
 867        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 868    )
 869    if bad_args:
 870        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 871
 872    return self.func(
 873        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 874    )
 875
 876
 877def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 878    names = []
 879    for agg in aggregations:
 880        if isinstance(agg, exp.Alias):
 881            names.append(agg.alias)
 882        else:
 883            """
 884            This case corresponds to aggregations without aliases being used as suffixes
 885            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 886            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
  887            Otherwise, we'd end up with `col_avg(`foo`)` (note the backticks around `foo`).
 888            """
 889            agg_all_unquoted = agg.transform(
 890                lambda node: (
 891                    exp.Identifier(this=node.name, quoted=False)
 892                    if isinstance(node, exp.Identifier)
 893                    else node
 894                )
 895            )
 896            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 897
 898    return names
 899
 900
 901def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 902    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 903
 904
 905# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 906def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 907    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 908
 909
 910def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 911    return self.func("MAX", expression.this)
 912
 913
 914def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 915    a = self.sql(expression.left)
 916    b = self.sql(expression.right)
 917    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 918
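# Editor's note — a sketch; Postgres, for one, maps exp.Xor to this helper:
#
#   >>> import sqlglot
#   >>> sqlglot.transpile("SELECT a XOR b", read="mysql", write="postgres")[0]
#   'SELECT (a AND (NOT b)) OR ((NOT a) AND b)'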
 919
 920def is_parse_json(expression: exp.Expression) -> bool:
 921    return isinstance(expression, exp.ParseJSON) or (
 922        isinstance(expression, exp.Cast) and expression.is_type("json")
 923    )
 924
 925
 926def isnull_to_is_null(args: t.List) -> exp.Expression:
 927    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 928
 929
 930def generatedasidentitycolumnconstraint_sql(
 931    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 932) -> str:
 933    start = self.sql(expression, "start") or "1"
 934    increment = self.sql(expression, "increment") or "1"
 935    return f"IDENTITY({start}, {increment})"
 936
 937
 938def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 939    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 940        if expression.args.get("count"):
 941            self.unsupported(f"Only two arguments are supported in function {name}.")
 942
 943        return self.func(name, expression.this, expression.expression)
 944
 945    return _arg_max_or_min_sql
 946
 947
 948def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 949    this = expression.this.copy()
 950
 951    return_type = expression.return_type
 952    if return_type.is_type(exp.DataType.Type.DATE):
 953        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 954        # can truncate timestamp strings, because some dialects can't cast them to DATE
 955        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 956
 957    expression.this.replace(exp.cast(this, return_type))
 958    return expression
 959
 960
 961def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 962    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 963        if cast and isinstance(expression, exp.TsOrDsAdd):
 964            expression = ts_or_ds_add_cast(expression)
 965
 966        return self.func(
 967            name,
 968            exp.var(expression.text("unit").upper() or "DAY"),
 969            expression.expression,
 970            expression.this,
 971        )
 972
 973    return _delta_sql
 974
 975
 976def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
 977    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
 978    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
 979    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
 980
 981    return self.sql(exp.cast(minus_one_day, "date"))
 982
 983
 984def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 985    """Remove table refs from columns in when statements."""
 986    alias = expression.this.args.get("alias")
 987
 988    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
 989        return self.dialect.normalize_identifier(identifier).name if identifier else None
 990
 991    targets = {normalize(expression.this.this)}
 992
 993    if alias:
 994        targets.add(normalize(alias.this))
 995
 996    for when in expression.expressions:
 997        when.transform(
 998            lambda node: (
 999                exp.column(node.this)
1000                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1001                else node
1002            ),
1003            copy=False,
1004        )
1005
1006    return self.merge_sql(expression)
1007
1008
1009def build_json_extract_path(
1010    expr_type: t.Type[F], zero_based_indexing: bool = True
1011) -> t.Callable[[t.List], F]:
1012    def _builder(args: t.List) -> F:
1013        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1014        for arg in args[1:]:
1015            if not isinstance(arg, exp.Literal):
1016                # We use the fallback parser because we can't really transpile non-literals safely
1017                return expr_type.from_arg_list(args)
1018
1019            text = arg.name
1020            if is_int(text):
1021                index = int(text)
1022                segments.append(
1023                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1024                )
1025            else:
1026                segments.append(exp.JSONPathKey(this=text))
1027
1028        # This is done to avoid failing in the expression validator due to the arg count
1029        del args[2:]
1030        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
1031
1032    return _builder
1033
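# Editor's note — a sketch of the folding above: literal arguments ('a', 2)
# become the JSON path $.a[2]; with zero_based_indexing=False the same call
# yields $.a[1] (one-based input converted to zero-based). Non-literal
# arguments fall back to expr_type.from_arg_list.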
1034
1035def json_extract_segments(
1036    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1037) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1038    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1039        path = expression.expression
1040        if not isinstance(path, exp.JSONPath):
1041            return rename_func(name)(self, expression)
1042
1043        segments = []
1044        for segment in path.expressions:
1045            path = self.sql(segment)
1046            if path:
1047                if isinstance(segment, exp.JSONPathPart) and (
1048                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1049                ):
1050                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1051
1052                segments.append(path)
1053
1054        if op:
1055            return f" {op} ".join([self.sql(expression.this), *segments])
1056        return self.func(name, expression.this, *segments)
1057
1058    return _json_extract_segments
1059
1060
1061def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1062    if isinstance(expression.this, exp.JSONPathWildcard):
1063        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1064
1065    return expression.name
1066
1067
1068def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1069    cond = expression.expression
1070    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1071        alias = cond.expressions[0]
1072        cond = cond.this
1073    elif isinstance(cond, exp.Predicate):
1074        alias = "_u"
1075    else:
1076        self.unsupported("Unsupported filter condition")
1077        return ""
1078
1079    unnest = exp.Unnest(expressions=[expression.this])
1080    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1081    return self.sql(exp.Array(expressions=[filtered]))
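
# Editor's note — a sketch of the rewrite above: an exp.ArrayFilter such as
# FILTER(arr, x -> x > 0) is emitted as a correlated subquery of roughly the
# shape ARRAY(SELECT x FROM UNNEST(arr) AS _t(x) WHERE x > 0); the exact alias
# rendering varies by dialect.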
487        return self.generator(**opts).generate(expression, copy=copy)
488
489    def transpile(self, sql: str, **opts) -> t.List[str]:
490        return [
491            self.generate(expression, copy=False, **opts) if expression else ""
492            for expression in self.parse(sql)
493        ]
494
495    def tokenize(self, sql: str) -> t.List[Token]:
496        return self.tokenizer.tokenize(sql)
497
498    @property
499    def tokenizer(self) -> Tokenizer:
500        if not hasattr(self, "_tokenizer"):
501            self._tokenizer = self.tokenizer_class(dialect=self)
502        return self._tokenizer
503
504    def parser(self, **opts) -> Parser:
505        return self.parser_class(dialect=self, **opts)
506
507    def generator(self, **opts) -> Generator:
508        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
361    def __init__(self, **kwargs) -> None:
362        normalization_strategy = kwargs.get("normalization_strategy")
363
364        if normalization_strategy is None:
365            self.normalization_strategy = self.NORMALIZATION_STRATEGY
366        else:
367            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
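
A quick sketch (not from the library's docs): the strategy kwarg is uppercased before the enum lookup, so lowercase spellings are accepted.

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# The value is uppercased before NormalizationStrategy(...) is called.
d = Dialect(normalization_strategy="case_sensitive")
assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
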
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents a percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:
  • "upper" or True: Convert names to uppercase.
  • "lower": Convert names to lowercase.
  • False: Disables function name normalization.

LOG_BASE_FIRST = True

Whether the base comes first in the LOG function.

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

ESCAPE_SEQUENCES: Dict[str, str] = {}

Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
297    @classmethod
298    def get_or_raise(cls, dialect: DialectType) -> Dialect:
299        """
300        Look up a dialect in the global dialect registry and return it if it exists.
301
302        Args:
303            dialect: The target dialect. If this is a string, it can be optionally followed by
304                additional key-value pairs that are separated by commas and are used to specify
305                dialect settings, such as whether the dialect's identifiers are case-sensitive.
306
307        Example:
308            >>> dialect = Dialect.get_or_raise("duckdb")
309            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
310
311        Returns:
312            The corresponding Dialect instance.
313        """
314
315        if not dialect:
316            return cls()
317        if isinstance(dialect, _Dialect):
318            return dialect()
319        if isinstance(dialect, Dialect):
320            return dialect
321        if isinstance(dialect, str):
322            try:
323                dialect_name, *kv_pairs = dialect.split(",")
324                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
325            except ValueError:
326                raise ValueError(
327                    f"Invalid dialect format: '{dialect}'. "
328                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
329                )
330
331            result = cls.get(dialect_name.strip())
332            if not result:
333                from difflib import get_close_matches
334
335                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
336                if similar:
337                    similar = f" Did you mean {similar}?"
338
339                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
340
341            return result(**kwargs)
342
343        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
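
As the source above shows, an unknown name raises a ValueError, with a close-match hint when one exists. A minimal sketch:

from sqlglot.dialects.dialect import Dialect

try:
    Dialect.get_or_raise("postgress")  # note the typo
except ValueError as e:
    print(e)  # Unknown dialect 'postgress'. Did you mean postgres?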

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
345    @classmethod
346    def format_time(
347        cls, expression: t.Optional[str | exp.Expression]
348    ) -> t.Optional[exp.Expression]:
349        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
350        if isinstance(expression, str):
351            return exp.Literal.string(
352                # the time formats are quoted
353                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
354            )
355
356        if expression and expression.is_string:
357            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
358
359        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
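
For instance, a sketch assuming the Hive dialect's TIME_MAPPING translates yyyy, MM and dd into %Y, %m and %d:

from sqlglot.dialects.dialect import Dialect

# String inputs are expected to be quoted, hence the surrounding single quotes.
fmt = Dialect["hive"].format_time("'yyyy-MM-dd'")
print(fmt.this)  # %Y-%m-%d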

def normalize_identifier(self, expression: ~E) -> ~E:
377    def normalize_identifier(self, expression: E) -> E:
378        """
379        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
380
381        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
382        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
383        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
384        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
385
386        There are also dialects like Spark, which are case-insensitive even when quotes are
387        present, and dialects like MySQL, whose resolution rules match those employed by the
388        underlying operating system, for example they may always be case-sensitive in Linux.
389
390        Finally, the normalization behavior of some engines can even be controlled through flags,
391        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
392
393        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
394        that it can analyze queries in the optimizer and successfully capture their semantics.
395        """
396        if (
397            isinstance(expression, exp.Identifier)
398            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
399            and (
400                not expression.quoted
401                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
402            )
403        ):
404            expression.set(
405                "this",
406                (
407                    expression.this.upper()
408                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
409                    else expression.this.lower()
410                ),
411            )
412
413        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
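
A small sketch of the contrast described above:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO", quoted=False)
print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).this)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).this)  # FOO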

def case_sensitive(self, text: str) -> bool:
415    def case_sensitive(self, text: str) -> bool:
416        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
417        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
418            return False
419
420        unsafe = (
421            str.islower
422            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
423            else str.isupper
424        )
425        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.
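
For a lowercase-normalizing dialect such as Postgres, any uppercase character makes the text "unsafe", since it would not survive normalization. A sketch:

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")
print(d.case_sensitive("foo"))  # False
print(d.case_sensitive("Foo"))  # True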

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
427    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
428        """Checks if text can be identified given an identify option.
429
430        Args:
431            text: The text to check.
432            identify:
433                `"always"` or `True`: Always returns `True`.
434                `"safe"`: Only returns `True` if the identifier is case-insensitive.
435
436        Returns:
437            Whether the given text can be identified.
438        """
439        if identify is True or identify == "always":
440            return True
441
442        if identify == "safe":
443            return not self.case_sensitive(text)
444
445        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify:
    "always" or True: Always returns True.
    "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
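
A sketch tying the two options together, again using Postgres' lowercase normalization:

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")
print(d.can_identify("Foo", identify="always"))  # True
print(d.can_identify("Foo", identify="safe"))    # False, 'Foo' is case-sensitive here
print(d.can_identify("foo", identify="safe"))    # True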

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
447    def quote_identifier(self, expression: E, identify: bool = True) -> E:
448        """
449        Adds quotes to a given identifier.
450
451        Args:
452            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
453            identify: If set to `False`, the quotes will only be added if the identifier is deemed
454                "unsafe", with respect to its characters and this dialect's normalization strategy.
455        """
456        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
457            name = expression.this
458            expression.set(
459                "quoted",
460                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
461            )
462
463        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
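
A sketch of both modes:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("postgres")
print(d.quote_identifier(exp.to_identifier("foo")).sql())                  # "foo"
print(d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql())  # "Foo", case-sensitive
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())  # foo, already safe
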
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
465    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
466        if isinstance(path, exp.Literal):
467            path_text = path.name
468            if path.is_number:
469                path_text = f"[{path_text}]"
470
471            try:
472                return parse_json_path(path_text)
473            except ParseError as e:
474                logger.warning(f"Invalid JSON path syntax. {str(e)}")
475
476        return path
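
A sketch: literal paths are parsed into structured exp.JSONPath nodes, while non-literals pass through unchanged:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

path = Dialect().to_json_path(exp.Literal.string("$.a.b"))
print(type(path).__name__)  # JSONPath
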
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
478    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
479        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
481    def parse_into(
482        self, expression_type: exp.IntoType, sql: str, **opts
483    ) -> t.List[t.Optional[exp.Expression]]:
484        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
486    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
487        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
489    def transpile(self, sql: str, **opts) -> t.List[str]:
490        return [
491            self.generate(expression, copy=False, **opts) if expression else ""
492            for expression in self.parse(sql)
493        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
495    def tokenize(self, sql: str) -> t.List[Token]:
496        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
498    @property
499    def tokenizer(self) -> Tokenizer:
500        if not hasattr(self, "_tokenizer"):
501            self._tokenizer = self.tokenizer_class(dialect=self)
502        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
504    def parser(self, **opts) -> Parser:
505        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
507    def generator(self, **opts) -> Generator:
508        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
514def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
515    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
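
A sketch of the typical use: plugging renamed functions into a generator's TRANSFORMS mapping (the APPROX_SET name is purely illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import rename_func

# Render exp.ApproxDistinct as APPROX_SET(...) in a hypothetical dialect.
TRANSFORMS = {exp.ApproxDistinct: rename_func("APPROX_SET")}
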
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
518def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
519    if expression.args.get("accuracy"):
520        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
521    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
524def if_sql(
525    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
526) -> t.Callable[[Generator, exp.If], str]:
527    def _if_sql(self: Generator, expression: exp.If) -> str:
528        return self.func(
529            name,
530            expression.this,
531            expression.args.get("true"),
532            expression.args.get("false") or false_value,
533        )
534
535    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
538def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
539    this = expression.this
540    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
541        this.replace(exp.cast(this, "json"))
542
543    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
546def inline_array_sql(self: Generator, expression: exp.Array) -> str:
547    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
550def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
551    return self.like_sql(
552        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
553    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
556def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
557    zone = self.sql(expression, "this")
558    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
561def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
562    if expression.args.get("recursive"):
563        self.unsupported("Recursive CTEs are unsupported")
564        expression.args["recursive"] = False
565    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
568def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
569    n = self.sql(expression, "this")
570    d = self.sql(expression, "expression")
571    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
574def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
575    self.unsupported("TABLESAMPLE unsupported")
576    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
579def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
580    self.unsupported("PIVOT unsupported")
581    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
584def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
585    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
588def no_comment_column_constraint_sql(
589    self: Generator, expression: exp.CommentColumnConstraint
590) -> str:
591    self.unsupported("CommentColumnConstraint unsupported")
592    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
595def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
596    self.unsupported("MAP_FROM_ENTRIES unsupported")
597    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
600def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
601    this = self.sql(expression, "this")
602    substr = self.sql(expression, "substr")
603    position = self.sql(expression, "position")
604    if position:
605        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
606    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
609def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
610    return (
611        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
612    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
615def var_map_sql(
616    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
617) -> str:
618    keys = expression.args["keys"]
619    values = expression.args["values"]
620
621    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
622        self.unsupported("Cannot convert array columns into map.")
623        return self.func(map_func_name, keys, values)
624
625    args = []
626    for key, value in zip(keys.expressions, values.expressions):
627        args.append(self.sql(key))
628        args.append(self.sql(value))
629
630    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
633def build_formatted_time(
634    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
635) -> t.Callable[[t.List], E]:
636    """Helper used for time expressions.
637
638    Args:
639        exp_class: the expression class to instantiate.
640        dialect: target sql dialect.
641        default: the default format, True being time.
642
643    Returns:
644        A callable that can be used to return the appropriately formatted time expression.
645    """
646
647    def _builder(args: t.List):
648        return exp_class(
649            this=seq_get(args, 0),
650            format=Dialect[dialect].format_time(
651                seq_get(args, 1)
652                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
653            ),
654        )
655
656    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
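
A sketch of registering such a builder in a parser's FUNCTIONS mapping, assuming a Hive-style format language:

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

# Builds exp.TimeToStr nodes whose format argument is converted via Hive's TIME_MAPPING.
FUNCTIONS = {"DATE_FORMAT": build_formatted_time(exp.TimeToStr, "hive")}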

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
659def time_format(
660    dialect: DialectType = None,
661) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
662    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
663        """
664        Returns the time format for a given expression, unless it's equivalent
665        to the default time format of the dialect of interest.
666        """
667        time_format = self.format_time(expression)
668        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
669
670    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
673def build_date_delta(
674    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
675) -> t.Callable[[t.List], E]:
676    def _builder(args: t.List) -> E:
677        unit_based = len(args) == 3
678        this = args[2] if unit_based else seq_get(args, 0)
679        unit = args[0] if unit_based else exp.Literal.string("DAY")
680        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
681        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
682
683    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
686def build_date_delta_with_interval(
687    expression_class: t.Type[E],
688) -> t.Callable[[t.List], t.Optional[E]]:
689    def _builder(args: t.List) -> t.Optional[E]:
690        if len(args) < 2:
691            return None
692
693        interval = args[1]
694
695        if not isinstance(interval, exp.Interval):
696            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
697
698        expression = interval.this
699        if expression and expression.is_string:
700            expression = exp.Literal.number(expression.this)
701
702        return expression_class(
703            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
704        )
705
706    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
709def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
710    unit = seq_get(args, 0)
711    this = seq_get(args, 1)
712
713    if isinstance(this, exp.Cast) and this.is_type("date"):
714        return exp.DateTrunc(unit=unit, this=this)
715    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
718def date_add_interval_sql(
719    data_type: str, kind: str
720) -> t.Callable[[Generator, exp.Expression], str]:
721    def func(self: Generator, expression: exp.Expression) -> str:
722        this = self.sql(expression, "this")
723        unit = expression.args.get("unit")
724        unit = exp.var(unit.name.upper() if unit else "DAY")
725        interval = exp.Interval(this=expression.expression, unit=unit)
726        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
727
728    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
731def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
732    return self.func(
733        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
734    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
737def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
738    if not expression.expression:
739        from sqlglot.optimizer.annotate_types import annotate_types
740
741        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
742        return self.sql(exp.cast(expression.this, to=target_type))
743    if expression.text("expression").lower() in TIMEZONES:
744        return self.sql(
745            exp.AtTimeZone(
746                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
747                zone=expression.expression,
748            )
749        )
750    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
753def locate_to_strposition(args: t.List) -> exp.Expression:
754    return exp.StrPosition(
755        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
756    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
759def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
760    return self.func(
761        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
762    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
765def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
766    return self.sql(
767        exp.Substring(
768            this=expression.this, start=exp.Literal.number(1), length=expression.expression
769        )
770    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
773def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
774    return self.sql(
775        exp.Substring(
776            this=expression.this,
777            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
778        )
779    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
782def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
783    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
786def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
787    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
791def encode_decode_sql(
792    self: Generator, expression: exp.Expression, name: str, replace: bool = True
793) -> str:
794    charset = expression.args.get("charset")
795    if charset and charset.name.lower() != "utf-8":
796        self.unsupported(f"Expected utf-8 character set, got {charset}.")
797
798    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
801def min_or_least(self: Generator, expression: exp.Min) -> str:
802    name = "LEAST" if expression.expressions else "MIN"
803    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
806def max_or_greatest(self: Generator, expression: exp.Max) -> str:
807    name = "GREATEST" if expression.expressions else "MAX"
808    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
811def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
812    cond = expression.this
813
814    if isinstance(expression.this, exp.Distinct):
815        cond = expression.this.expressions[0]
816        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
817
818    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
821def trim_sql(self: Generator, expression: exp.Trim) -> str:
822    target = self.sql(expression, "this")
823    trim_type = self.sql(expression, "position")
824    remove_chars = self.sql(expression, "expression")
825    collation = self.sql(expression, "collation")
826
827    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
828    if not remove_chars and not collation:
829        return self.trim_sql(expression)
830
831    trim_type = f"{trim_type} " if trim_type else ""
832    remove_chars = f"{remove_chars} " if remove_chars else ""
833    from_part = "FROM " if trim_type or remove_chars else ""
834    collation = f" COLLATE {collation}" if collation else ""
835    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
838def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
839    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
842def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
843    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
846def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
847    delim, *rest_args = expression.expressions
848    return self.sql(
849        reduce(
850            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
851            rest_args,
852        )
853    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
856def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
857    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
858    if bad_args:
859        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
860
861    return self.func(
862        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
863    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
866def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
867    bad_args = list(
868        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
869    )
870    if bad_args:
871        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
872
873    return self.func(
874        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
875    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
878def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
879    names = []
880    for agg in aggregations:
881        if isinstance(agg, exp.Alias):
882            names.append(agg.alias)
883        else:
884            """
885            This case corresponds to aggregations without aliases being used as suffixes
886            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
887            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
888            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
889            """
890            agg_all_unquoted = agg.transform(
891                lambda node: (
892                    exp.Identifier(this=node.name, quoted=False)
893                    if isinstance(node, exp.Identifier)
894                    else node
895                )
896            )
897            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
898
899    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
902def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
903    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
907def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
908    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
911def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
912    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
915def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
916    a = self.sql(expression.left)
917    b = self.sql(expression.right)
918    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
921def is_parse_json(expression: exp.Expression) -> bool:
922    return isinstance(expression, exp.ParseJSON) or (
923        isinstance(expression, exp.Cast) and expression.is_type("json")
924    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
927def isnull_to_is_null(args: t.List) -> exp.Expression:
928    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
931def generatedasidentitycolumnconstraint_sql(
932    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
933) -> str:
934    start = self.sql(expression, "start") or "1"
935    increment = self.sql(expression, "increment") or "1"
936    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
939def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
940    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
941        if expression.args.get("count"):
942            self.unsupported(f"Only two arguments are supported in function {name}.")
943
944        return self.func(name, expression.this, expression.expression)
945
946    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
949def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
950    this = expression.this.copy()
951
952    return_type = expression.return_type
953    if return_type.is_type(exp.DataType.Type.DATE):
954        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
955        # can truncate timestamp strings, because some dialects can't cast them to DATE
956        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
957
958    expression.this.replace(exp.cast(this, return_type))
959    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
962def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
963    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
964        if cast and isinstance(expression, exp.TsOrDsAdd):
965            expression = ts_or_ds_add_cast(expression)
966
967        return self.func(
968            name,
969            exp.var(expression.text("unit").upper() or "DAY"),
970            expression.expression,
971            expression.this,
972        )
973
974    return _delta_sql
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
977def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
978    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
979    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
980    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
981
982    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
 985def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 986    """Remove table refs from columns in when statements."""
 987    alias = expression.this.args.get("alias")
 988
 989    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
 990        return self.dialect.normalize_identifier(identifier).name if identifier else None
 991
 992    targets = {normalize(expression.this.this)}
 993
 994    if alias:
 995        targets.add(normalize(alias.this))
 996
 997    for when in expression.expressions:
 998        when.transform(
 999            lambda node: (
1000                exp.column(node.this)
1001                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1002                else node
1003            ),
1004            copy=False,
1005        )
1006
1007    return self.merge_sql(expression)

Remove table refs from columns in when statements.
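
A sketch of wiring the helper into a generator for an engine that can't qualify columns with the MERGE target (the mapping itself is illustrative):

from sqlglot import exp
from sqlglot.dialects.dialect import merge_without_target_sql

TRANSFORMS = {exp.Merge: merge_without_target_sql}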

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True) -> Callable[[List], ~F]:
1010def build_json_extract_path(
1011    expr_type: t.Type[F], zero_based_indexing: bool = True
1012) -> t.Callable[[t.List], F]:
1013    def _builder(args: t.List) -> F:
1014        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1015        for arg in args[1:]:
1016            if not isinstance(arg, exp.Literal):
1017                # We use the fallback parser because we can't really transpile non-literals safely
1018                return expr_type.from_arg_list(args)
1019
1020            text = arg.name
1021            if is_int(text):
1022                index = int(text)
1023                segments.append(
1024                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1025                )
1026            else:
1027                segments.append(exp.JSONPathKey(this=text))
1028
1029        # This is done to avoid failing in the expression validator due to the arg count
1030        del args[2:]
1031        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
1032
1033    return _builder
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1036def json_extract_segments(
1037    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1038) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1039    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1040        path = expression.expression
1041        if not isinstance(path, exp.JSONPath):
1042            return rename_func(name)(self, expression)
1043
1044        segments = []
1045        for segment in path.expressions:
1046            path = self.sql(segment)
1047            if path:
1048                if isinstance(segment, exp.JSONPathPart) and (
1049                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1050                ):
1051                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1052
1053                segments.append(path)
1054
1055        if op:
1056            return f" {op} ".join([self.sql(expression.this), *segments])
1057        return self.func(name, expression.this, *segments)
1058
1059    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1062def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1063    if isinstance(expression.this, exp.JSONPathWildcard):
1064        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1065
1066    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1069def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1070    cond = expression.expression
1071    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1072        alias = cond.expressions[0]
1073        cond = cond.this
1074    elif isinstance(cond, exp.Predicate):
1075        alias = "_u"
1076    else:
1077        self.unsupported("Unsupported filter condition")
1078        return ""
1079
1080    unnest = exp.Unnest(expressions=[expression.this])
1081    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1082    return self.sql(exp.Array(expressions=[filtered]))