sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import logging
   4import typing as t
   5from enum import Enum, auto
   6from functools import reduce
   7
   8from sqlglot import exp
   9from sqlglot.errors import ParseError
  10from sqlglot.generator import Generator
  11from sqlglot.helper import AutoName, flatten, is_int, seq_get
  12from sqlglot.jsonpath import parse as parse_json_path
  13from sqlglot.parser import Parser
  14from sqlglot.time import TIMEZONES, format_time
  15from sqlglot.tokens import Token, Tokenizer, TokenType
  16from sqlglot.trie import new_trie
  17
  18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
  21
  22
  23if t.TYPE_CHECKING:
  24    from sqlglot._typing import B, E, F
  25
  26logger = logging.getLogger("sqlglot")
  27
  28
  29class Dialects(str, Enum):
  30    """Dialects supported by SQLGLot."""
  31
  32    DIALECT = ""
  33
  34    ATHENA = "athena"
  35    BIGQUERY = "bigquery"
  36    CLICKHOUSE = "clickhouse"
  37    DATABRICKS = "databricks"
  38    DORIS = "doris"
  39    DRILL = "drill"
  40    DUCKDB = "duckdb"
  41    HIVE = "hive"
  42    MYSQL = "mysql"
  43    ORACLE = "oracle"
  44    POSTGRES = "postgres"
  45    PRESTO = "presto"
  46    PRQL = "prql"
  47    REDSHIFT = "redshift"
  48    SNOWFLAKE = "snowflake"
  49    SPARK = "spark"
  50    SPARK2 = "spark2"
  51    SQLITE = "sqlite"
  52    STARROCKS = "starrocks"
  53    TABLEAU = "tableau"
  54    TERADATA = "teradata"
  55    TRINO = "trino"
  56    TSQL = "tsql"
  57
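# Because `Dialects` subclasses `str`, its members can be passed anywhere a
# dialect name is expected. A minimal sketch, assuming sqlglot's top-level API:
#
#   import sqlglot
#   sqlglot.transpile("SELECT 1", read=Dialects.MYSQL, write=Dialects.DUCKDB)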
  58
  59class NormalizationStrategy(str, AutoName):
  60    """Specifies the strategy according to which identifiers should be normalized."""
  61
  62    LOWERCASE = auto()
  63    """Unquoted identifiers are lowercased."""
  64
  65    UPPERCASE = auto()
  66    """Unquoted identifiers are uppercased."""
  67
  68    CASE_SENSITIVE = auto()
  69    """Always case-sensitive, regardless of quotes."""
  70
  71    CASE_INSENSITIVE = auto()
  72    """Always case-insensitive, regardless of quotes."""
  73
  74
  75class _Dialect(type):
  76    classes: t.Dict[str, t.Type[Dialect]] = {}
  77
  78    def __eq__(cls, other: t.Any) -> bool:
  79        if cls is other:
  80            return True
  81        if isinstance(other, str):
  82            return cls is cls.get(other)
  83        if isinstance(other, Dialect):
  84            return cls is type(other)
  85
  86        return False
  87
  88    def __hash__(cls) -> int:
  89        return hash(cls.__name__.lower())
  90
  91    @classmethod
  92    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  93        return cls.classes[key]
  94
  95    @classmethod
  96    def get(
  97        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  98    ) -> t.Optional[t.Type[Dialect]]:
  99        return cls.classes.get(key, default)
 100
 101    def __new__(cls, clsname, bases, attrs):
 102        klass = super().__new__(cls, clsname, bases, attrs)
 103        enum = Dialects.__members__.get(clsname.upper())
 104        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 105
 106        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 107        klass.FORMAT_TRIE = (
 108            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 109        )
 110        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 111        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 112
 113        base = seq_get(bases, 0)
 114        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
 115        base_parser = (getattr(base, "parser_class", Parser),)
 116        base_generator = (getattr(base, "generator_class", Generator),)
 117
 118        klass.tokenizer_class = klass.__dict__.get(
 119            "Tokenizer", type("Tokenizer", base_tokenizer, {})
 120        )
 121        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
 122        klass.generator_class = klass.__dict__.get(
 123            "Generator", type("Generator", base_generator, {})
 124        )
 125
 126        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 127        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 128            klass.tokenizer_class._IDENTIFIERS.items()
 129        )[0]
 130
 131        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 132            return next(
 133                (
 134                    (s, e)
 135                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 136                    if t == token_type
 137                ),
 138                (None, None),
 139            )
 140
 141        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 142        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 143        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 144        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 145
 146        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
 147            klass.UNESCAPED_SEQUENCES = {
 148                "\\a": "\a",
 149                "\\b": "\b",
 150                "\\f": "\f",
 151                "\\n": "\n",
 152                "\\r": "\r",
 153                "\\t": "\t",
 154                "\\v": "\v",
 155                "\\\\": "\\",
 156                **klass.UNESCAPED_SEQUENCES,
 157            }
 158
 159        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
 160
 161        if enum not in ("", "bigquery"):
 162            klass.generator_class.SELECT_KINDS = ()
 163
 164        if enum not in ("", "databricks", "hive", "spark", "spark2"):
 165            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
 166            for modifier in ("cluster", "distribute", "sort"):
 167                modifier_transforms.pop(modifier, None)
 168
 169            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
 170
 171        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 172            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 173                TokenType.ANTI,
 174                TokenType.SEMI,
 175            }
 176
 177        return klass
 178
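# A quick sketch of what the metaclass provides: every `Dialect` subclass is
# registered under its lowercased name, and dialect classes compare equal to
# their registry key (assumes the dialect modules have been imported, e.g. via
# `import sqlglot`):
#
#   from sqlglot.dialects.dialect import Dialect
#   from sqlglot.dialects.duckdb import DuckDB
#
#   assert Dialect.get("duckdb") is DuckDB
#   assert DuckDB == "duckdb"  # resolved through _Dialect.__eq__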
 179
 180class Dialect(metaclass=_Dialect):
 181    INDEX_OFFSET = 0
 182    """The base index offset for arrays."""
 183
 184    WEEK_OFFSET = 0
 185    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 186
 187    UNNEST_COLUMN_ONLY = False
 188    """Whether `UNNEST` table aliases are treated as column aliases."""
 189
 190    ALIAS_POST_TABLESAMPLE = False
 191    """Whether the table alias comes after tablesample."""
 192
 193    TABLESAMPLE_SIZE_IS_PERCENT = False
 194    """Whether a size in the table sample clause represents percentage."""
 195
 196    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 197    """Specifies the strategy according to which identifiers should be normalized."""
 198
 199    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 200    """Whether an unquoted identifier can start with a digit."""
 201
 202    DPIPE_IS_STRING_CONCAT = True
 203    """Whether the DPIPE token (`||`) is a string concatenation operator."""
 204
 205    STRICT_STRING_CONCAT = False
 206    """Whether `CONCAT`'s arguments must be strings."""
 207
 208    SUPPORTS_USER_DEFINED_TYPES = True
 209    """Whether user-defined data types are supported."""
 210
 211    SUPPORTS_SEMI_ANTI_JOIN = True
 212    """Whether `SEMI` or `ANTI` joins are supported."""
 213
 214    NORMALIZE_FUNCTIONS: bool | str = "upper"
 215    """
 216    Determines how function names are going to be normalized.
 217    Possible values:
 218        "upper" or True: Convert names to uppercase.
 219        "lower": Convert names to lowercase.
 220        False: Disables function name normalization.
 221    """
 222
 223    LOG_BASE_FIRST: t.Optional[bool] = True
 224    """
 225    Whether the base comes first in the `LOG` function.
 226    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
 227    """
 228
 229    NULL_ORDERING = "nulls_are_small"
 230    """
 231    Default `NULL` ordering method to use if not explicitly set.
 232    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 233    """
 234
 235    TYPED_DIVISION = False
 236    """
 237    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 238    False means `a / b` is always float division.
 239    True means `a / b` is integer division if both `a` and `b` are integers.
 240    """
 241
 242    SAFE_DIVISION = False
 243    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 244
 245    CONCAT_COALESCE = False
 246    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 247
 248    DATE_FORMAT = "'%Y-%m-%d'"
 249    DATEINT_FORMAT = "'%Y%m%d'"
 250    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 251
 252    TIME_MAPPING: t.Dict[str, str] = {}
 253    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
 254
 255    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 256    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 257    FORMAT_MAPPING: t.Dict[str, str] = {}
 258    """
 259    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 260    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 261    """
 262
 263    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
 264    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 265
 266    PSEUDOCOLUMNS: t.Set[str] = set()
 267    """
 268    Columns that are auto-generated by the engine corresponding to this dialect.
 269    For example, such columns may be excluded from `SELECT *` queries.
 270    """
 271
 272    PREFER_CTE_ALIAS_COLUMN = False
 273    """
 274    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 275    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 276    any projection aliases in the subquery.
 277
 278    For example,
 279        WITH y(c) AS (
 280            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 281        ) SELECT c FROM y;
 282
 283        will be rewritten as
 284
 285        WITH y(c) AS (
 286            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 287        ) SELECT c FROM y;
 288    """
 289
 290    # --- Autofilled ---
 291
 292    tokenizer_class = Tokenizer
 293    parser_class = Parser
 294    generator_class = Generator
 295
 296    # A trie of the time_mapping keys
 297    TIME_TRIE: t.Dict = {}
 298    FORMAT_TRIE: t.Dict = {}
 299
 300    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 301    INVERSE_TIME_TRIE: t.Dict = {}
 302
 303    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 304
 305    # Delimiters for string literals and identifiers
 306    QUOTE_START = "'"
 307    QUOTE_END = "'"
 308    IDENTIFIER_START = '"'
 309    IDENTIFIER_END = '"'
 310
 311    # Delimiters for bit, hex, byte and unicode literals
 312    BIT_START: t.Optional[str] = None
 313    BIT_END: t.Optional[str] = None
 314    HEX_START: t.Optional[str] = None
 315    HEX_END: t.Optional[str] = None
 316    BYTE_START: t.Optional[str] = None
 317    BYTE_END: t.Optional[str] = None
 318    UNICODE_START: t.Optional[str] = None
 319    UNICODE_END: t.Optional[str] = None
 320
 321    @classmethod
 322    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 323        """
 324        Look up a dialect in the global dialect registry and return it if it exists.
 325
 326        Args:
 327            dialect: The target dialect. If this is a string, it can be optionally followed by
 328                additional key-value pairs that are separated by commas and are used to specify
 329                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 330
 331        Example:
 332            >>> dialect = dialect_class = get_or_raise("duckdb")
 333            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
 334
 335        Returns:
 336            The corresponding Dialect instance.
 337        """
 338
 339        if not dialect:
 340            return cls()
 341        if isinstance(dialect, _Dialect):
 342            return dialect()
 343        if isinstance(dialect, Dialect):
 344            return dialect
 345        if isinstance(dialect, str):
 346            try:
 347                dialect_name, *kv_pairs = dialect.split(",")
 348                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 349            except ValueError:
 350                raise ValueError(
 351                    f"Invalid dialect format: '{dialect}'. "
 352                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
 353                )
 354
 355            result = cls.get(dialect_name.strip())
 356            if not result:
 357                from difflib import get_close_matches
 358
 359                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 360                if similar:
 361                    similar = f" Did you mean {similar}?"
 362
 363                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 364
 365            return result(**kwargs)
 366
 367        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 368
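# A minimal usage sketch for `get_or_raise`, including the comma-separated
# settings syntax described in the docstring:
#
#   dialect = Dialect.get_or_raise("duckdb")  # -> DuckDB instance
#   mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
#   assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
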
 369    @classmethod
 370    def format_time(
 371        cls, expression: t.Optional[str | exp.Expression]
 372    ) -> t.Optional[exp.Expression]:
 373        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 374        if isinstance(expression, str):
 375            return exp.Literal.string(
 376                # the time formats are quoted
 377                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 378            )
 379
 380        if expression and expression.is_string:
 381            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 382
 383        return expression
 384
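# Sketch: on a dialect with a populated TIME_MAPPING this rewrites format
# tokens into strftime ones. Assuming Hive maps 'yyyy'/'MM'/'dd' as usual:
#
#   from sqlglot.dialects.hive import Hive
#   lit = Hive.format_time("'yyyy-MM-dd'")  # string input is expected quoted
#   assert lit.this == "%Y-%m-%d"
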
 385    def __init__(self, **kwargs) -> None:
 386        normalization_strategy = kwargs.get("normalization_strategy")
 387
 388        if normalization_strategy is None:
 389            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 390        else:
 391            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 392
 393    def __eq__(self, other: t.Any) -> bool:
 394        # Does not currently take dialect state into account
 395        return type(self) == other
 396
 397    def __hash__(self) -> int:
 398        # Does not currently take dialect state into account
 399        return hash(type(self))
 400
 401    def normalize_identifier(self, expression: E) -> E:
 402        """
 403        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 404
 405        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 406        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 407        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 408        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 409
 410        There are also dialects like Spark, which are case-insensitive even when quotes are
 411        present, and dialects like MySQL, whose resolution rules match those employed by the
 412        underlying operating system; for example, they may always be case-sensitive on Linux.
 413
 414        Finally, the normalization behavior of some engines can even be controlled through flags,
 415        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 416
 417        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 418        that it can analyze queries in the optimizer and successfully capture their semantics.
 419        """
 420        if (
 421            isinstance(expression, exp.Identifier)
 422            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 423            and (
 424                not expression.quoted
 425                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 426            )
 427        ):
 428            expression.set(
 429                "this",
 430                (
 431                    expression.this.upper()
 432                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 433                    else expression.this.lower()
 434                ),
 435            )
 436
 437        return expression
 438
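# Sketch: Snowflake's strategy is UPPERCASE, so unquoted identifiers are
# uppercased while quoted ones are preserved:
#
#   from sqlglot import exp
#   snow = Dialect.get_or_raise("snowflake")
#   assert snow.normalize_identifier(exp.to_identifier("FoO")).name == "FOO"
#   assert snow.normalize_identifier(exp.to_identifier("FoO", quoted=True)).name == "FoO"
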
 439    def case_sensitive(self, text: str) -> bool:
 440        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 441        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 442            return False
 443
 444        unsafe = (
 445            str.islower
 446            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 447            else str.isupper
 448        )
 449        return any(unsafe(char) for char in text)
 450
 451    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 452        """Checks if text can be identified given an identify option.
 453
 454        Args:
 455            text: The text to check.
 456            identify:
 457                `"always"` or `True`: Always returns `True`.
 458                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 459
 460        Returns:
 461            Whether the given text can be identified.
 462        """
 463        if identify is True or identify == "always":
 464            return True
 465
 466        if identify == "safe":
 467            return not self.case_sensitive(text)
 468
 469        return False
 470
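# Sketch: under the default LOWERCASE strategy, uppercase characters are the
# "unsafe" ones, so only all-lowercase text is identifiable in "safe" mode:
#
#   d = Dialect()
#   assert d.case_sensitive("Foo") and not d.case_sensitive("foo")
#   assert d.can_identify("foo", "safe") and not d.can_identify("Foo", "safe")
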
 471    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 472        """
 473        Adds quotes to a given identifier.
 474
 475        Args:
 476            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 477            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 478                "unsafe", with respect to its characters and this dialect's normalization strategy.
 479        """
 480        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
 481            name = expression.this
 482            expression.set(
 483                "quoted",
 484                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 485            )
 486
 487        return expression
 488
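# Sketch: forcing quotes onto an otherwise safe identifier (the default
# dialect renders identifiers with double quotes):
#
#   from sqlglot import exp
#   ident = Dialect().quote_identifier(exp.to_identifier("order"))
#   assert ident.sql() == '"order"'
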
 489    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
 490        if isinstance(path, exp.Literal):
 491            path_text = path.name
 492            if path.is_number:
 493                path_text = f"[{path_text}]"
 494
 495            try:
 496                return parse_json_path(path_text)
 497            except ParseError as e:
 498                logger.warning(f"Invalid JSON path syntax. {str(e)}")
 499
 500        return path
 501
 502    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 503        return self.parser(**opts).parse(self.tokenize(sql), sql)
 504
 505    def parse_into(
 506        self, expression_type: exp.IntoType, sql: str, **opts
 507    ) -> t.List[t.Optional[exp.Expression]]:
 508        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 509
 510    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 511        return self.generator(**opts).generate(expression, copy=copy)
 512
 513    def transpile(self, sql: str, **opts) -> t.List[str]:
 514        return [
 515            self.generate(expression, copy=False, **opts) if expression else ""
 516            for expression in self.parse(sql)
 517        ]
 518
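# Sketch: `transpile` chains tokenize -> parse -> generate, so a per-dialect
# round trip is a single call:
#
#   assert Dialect.get_or_raise("duckdb").transpile("SELECT 1") == ["SELECT 1"]
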
 519    def tokenize(self, sql: str) -> t.List[Token]:
 520        return self.tokenizer.tokenize(sql)
 521
 522    @property
 523    def tokenizer(self) -> Tokenizer:
 524        if not hasattr(self, "_tokenizer"):
 525            self._tokenizer = self.tokenizer_class(dialect=self)
 526        return self._tokenizer
 527
 528    def parser(self, **opts) -> Parser:
 529        return self.parser_class(dialect=self, **opts)
 530
 531    def generator(self, **opts) -> Generator:
 532        return self.generator_class(dialect=self, **opts)
 533
 534
 535DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 536
 537
 538def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 539    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 540
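# Sketch: `rename_func` builds a generator transform; dialects typically wire
# it into their `Generator.TRANSFORMS` mapping (`MyDialect` is hypothetical):
#
#   from sqlglot import exp, generator
#
#   class MyDialect(Dialect):  # hypothetical dialect
#       class Generator(generator.Generator):
#           TRANSFORMS = {
#               **generator.Generator.TRANSFORMS,
#               exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
#           }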
 541
 542def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 543    if expression.args.get("accuracy"):
 544        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 545    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 546
 547
 548def if_sql(
 549    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 550) -> t.Callable[[Generator, exp.If], str]:
 551    def _if_sql(self: Generator, expression: exp.If) -> str:
 552        return self.func(
 553            name,
 554            expression.this,
 555            expression.args.get("true"),
 556            expression.args.get("false") or false_value,
 557        )
 558
 559    return _if_sql
 560
 561
 562def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
 563    this = expression.this
 564    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
 565        this.replace(exp.cast(this, "json"))
 566
 567    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
 568
 569
 570def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 571    return f"[{self.expressions(expression, flat=True)}]"
 572
 573
 574def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 575    return self.like_sql(
 576        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 577    )
 578
 579
 580def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 581    zone = self.sql(expression, "this")
 582    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 583
 584
 585def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 586    if expression.args.get("recursive"):
 587        self.unsupported("Recursive CTEs are unsupported")
 588        expression.args["recursive"] = False
 589    return self.with_sql(expression)
 590
 591
 592def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 593    n = self.sql(expression, "this")
 594    d = self.sql(expression, "expression")
 595    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 596
 597
 598def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 599    self.unsupported("TABLESAMPLE unsupported")
 600    return self.sql(expression.this)
 601
 602
 603def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 604    self.unsupported("PIVOT unsupported")
 605    return ""
 606
 607
 608def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 609    return self.cast_sql(expression)
 610
 611
 612def no_comment_column_constraint_sql(
 613    self: Generator, expression: exp.CommentColumnConstraint
 614) -> str:
 615    self.unsupported("CommentColumnConstraint unsupported")
 616    return ""
 617
 618
 619def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 620    self.unsupported("MAP_FROM_ENTRIES unsupported")
 621    return ""
 622
 623
 624def str_position_sql(
 625    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
 626) -> str:
 627    this = self.sql(expression, "this")
 628    substr = self.sql(expression, "substr")
 629    position = self.sql(expression, "position")
 630    instance = expression.args.get("instance") if generate_instance else None
 631    position_offset = ""
 632
 633    if position:
 634        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
 635        this = self.func("SUBSTR", this, position)
 636        position_offset = f" + {position} - 1"
 637
 638    return self.func("STRPOS", this, substr, instance) + position_offset
 639
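# For example, under this scheme STR_POSITION(x, 'a', 3) renders roughly as
# STRPOS(SUBSTR(x, 3), 'a') + 3 - 1: the search runs over a suffix and the
# consumed offset is added back to the result.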
 640
 641def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 642    return (
 643        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 644    )
 645
 646
 647def var_map_sql(
 648    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 649) -> str:
 650    keys = expression.args["keys"]
 651    values = expression.args["values"]
 652
 653    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 654        self.unsupported("Cannot convert array columns into map.")
 655        return self.func(map_func_name, keys, values)
 656
 657    args = []
 658    for key, value in zip(keys.expressions, values.expressions):
 659        args.append(self.sql(key))
 660        args.append(self.sql(value))
 661
 662    return self.func(map_func_name, *args)
 663
 664
 665def build_formatted_time(
 666    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 667) -> t.Callable[[t.List], E]:
 668    """Helper used for time expressions.
 669
 670    Args:
 671        exp_class: the expression class to instantiate.
 672        dialect: target sql dialect.
 673        default: the default format, True being time.
 674
 675    Returns:
 676        A callable that can be used to return the appropriately formatted time expression.
 677    """
 678
 679    def _builder(args: t.List):
 680        return exp_class(
 681            this=seq_get(args, 0),
 682            format=Dialect[dialect].format_time(
 683                seq_get(args, 1)
 684                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 685            ),
 686        )
 687
 688    return _builder
 689
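# Sketch: dialects register this as a function builder in their parser; the
# dialect name argument selects whose TIME_MAPPING drives the conversion
# ("hive" below is only illustrative):
#
#   from sqlglot import exp, parser
#
#   class Parser(parser.Parser):  # nested inside a hypothetical dialect
#       FUNCTIONS = {
#           **parser.Parser.FUNCTIONS,
#           "TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "hive"),
#       }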
 690
 691def time_format(
 692    dialect: DialectType = None,
 693) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 694    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 695        """
 696        Returns the time format for a given expression, unless it's equivalent
 697        to the default time format of the dialect of interest.
 698        """
 699        time_format = self.format_time(expression)
 700        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 701
 702    return _time_format
 703
 704
 705def build_date_delta(
 706    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 707) -> t.Callable[[t.List], E]:
 708    def _builder(args: t.List) -> E:
 709        unit_based = len(args) == 3
 710        this = args[2] if unit_based else seq_get(args, 0)
 711        unit = args[0] if unit_based else exp.Literal.string("DAY")
 712        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 713        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 714
 715    return _builder
 716
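# Sketch: with three args the unit comes first, and a unit mapping can
# normalize raw unit names (values below are illustrative):
#
#   builder = build_date_delta(exp.DateDiff, {"dd": "DAY"})
#   node = builder([exp.Literal.string("dd"), exp.column("a"), exp.column("b")])
#   # -> DateDiff(this=b, expression=a, unit=Var(this=DAY))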
 717
 718def build_date_delta_with_interval(
 719    expression_class: t.Type[E],
 720) -> t.Callable[[t.List], t.Optional[E]]:
 721    def _builder(args: t.List) -> t.Optional[E]:
 722        if len(args) < 2:
 723            return None
 724
 725        interval = args[1]
 726
 727        if not isinstance(interval, exp.Interval):
 728            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 729
 730        expression = interval.this
 731        if expression and expression.is_string:
 732            expression = exp.Literal.number(expression.this)
 733
 734        return expression_class(
 735            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
 736        )
 737
 738    return _builder
 739
 740
 741def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 742    unit = seq_get(args, 0)
 743    this = seq_get(args, 1)
 744
 745    if isinstance(this, exp.Cast) and this.is_type("date"):
 746        return exp.DateTrunc(unit=unit, this=this)
 747    return exp.TimestampTrunc(this=this, unit=unit)
 748
 749
 750def date_add_interval_sql(
 751    data_type: str, kind: str
 752) -> t.Callable[[Generator, exp.Expression], str]:
 753    def func(self: Generator, expression: exp.Expression) -> str:
 754        this = self.sql(expression, "this")
 755        unit = expression.args.get("unit")
 756        unit = exp.var(unit.name.upper() if unit else "DAY")
 757        interval = exp.Interval(this=expression.expression, unit=unit)
 758        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 759
 760    return func
 761
 762
 763def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 764    return self.func(
 765        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
 766    )
 767
 768
 769def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 770    if not expression.expression:
 771        from sqlglot.optimizer.annotate_types import annotate_types
 772
 773        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
 774        return self.sql(exp.cast(expression.this, to=target_type))
 775    if expression.text("expression").lower() in TIMEZONES:
 776        return self.sql(
 777            exp.AtTimeZone(
 778                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
 779                zone=expression.expression,
 780            )
 781        )
 782    return self.func("TIMESTAMP", expression.this, expression.expression)
 783
 784
 785def locate_to_strposition(args: t.List) -> exp.Expression:
 786    return exp.StrPosition(
 787        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 788    )
 789
 790
 791def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 792    return self.func(
 793        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 794    )
 795
 796
 797def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 798    return self.sql(
 799        exp.Substring(
 800            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 801        )
 802    )
 803
 804
 805def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
 806    return self.sql(
 807        exp.Substring(
 808            this=expression.this,
 809            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 810        )
 811    )
 812
 813
 814def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 815    return self.sql(exp.cast(expression.this, "timestamp"))
 816
 817
 818def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 819    return self.sql(exp.cast(expression.this, "date"))
 820
 821
 822# Used for Presto and DuckDB, which use functions that don't support charset and assume utf-8
 823def encode_decode_sql(
 824    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 825) -> str:
 826    charset = expression.args.get("charset")
 827    if charset and charset.name.lower() != "utf-8":
 828        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 829
 830    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 831
 832
 833def min_or_least(self: Generator, expression: exp.Min) -> str:
 834    name = "LEAST" if expression.expressions else "MIN"
 835    return rename_func(name)(self, expression)
 836
 837
 838def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 839    name = "GREATEST" if expression.expressions else "MAX"
 840    return rename_func(name)(self, expression)
 841
 842
 843def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 844    cond = expression.this
 845
 846    if isinstance(expression.this, exp.Distinct):
 847        cond = expression.this.expressions[0]
 848        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 849
 850    return self.func("sum", exp.func("if", cond, 1, 0))
 851
 852
 853def trim_sql(self: Generator, expression: exp.Trim) -> str:
 854    target = self.sql(expression, "this")
 855    trim_type = self.sql(expression, "position")
 856    remove_chars = self.sql(expression, "expression")
 857    collation = self.sql(expression, "collation")
 858
 859    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 860    if not remove_chars and not collation:
 861        return self.trim_sql(expression)
 862
 863    trim_type = f"{trim_type} " if trim_type else ""
 864    remove_chars = f"{remove_chars} " if remove_chars else ""
 865    from_part = "FROM " if trim_type or remove_chars else ""
 866    collation = f" COLLATE {collation}" if collation else ""
 867    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 868
 869
 870def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 871    return self.func("STRPTIME", expression.this, self.format_time(expression))
 872
 873
 874def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 875    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 876
 877
 878def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 879    delim, *rest_args = expression.expressions
 880    return self.sql(
 881        reduce(
 882            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 883            rest_args,
 884        )
 885    )
 886
 887
 888def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 889    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 890    if bad_args:
 891        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 892
 893    return self.func(
 894        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 895    )
 896
 897
 898def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 899    bad_args = list(
 900        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 901    )
 902    if bad_args:
 903        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 904
 905    return self.func(
 906        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 907    )
 908
 909
 910def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 911    names = []
 912    for agg in aggregations:
 913        if isinstance(agg, exp.Alias):
 914            names.append(agg.alias)
 915        else:
 916            """
 917            This case corresponds to aggregations without aliases being used as suffixes
 918            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 919            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 920            Otherwise, we'd end up with `col_avg(`foo`)` (notice the quotes around `foo`).
 921            """
 922            agg_all_unquoted = agg.transform(
 923                lambda node: (
 924                    exp.Identifier(this=node.name, quoted=False)
 925                    if isinstance(node, exp.Identifier)
 926                    else node
 927                )
 928            )
 929            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 930
 931    return names
 932
 933
 934def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 935    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 936
 937
 938# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 939def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 940    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 941
 942
 943def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 944    return self.func("MAX", expression.this)
 945
 946
 947def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 948    a = self.sql(expression.left)
 949    b = self.sql(expression.right)
 950    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 951
 952
 953def is_parse_json(expression: exp.Expression) -> bool:
 954    return isinstance(expression, exp.ParseJSON) or (
 955        isinstance(expression, exp.Cast) and expression.is_type("json")
 956    )
 957
 958
 959def isnull_to_is_null(args: t.List) -> exp.Expression:
 960    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 961
 962
 963def generatedasidentitycolumnconstraint_sql(
 964    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 965) -> str:
 966    start = self.sql(expression, "start") or "1"
 967    increment = self.sql(expression, "increment") or "1"
 968    return f"IDENTITY({start}, {increment})"
 969
 970
 971def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 972    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 973        if expression.args.get("count"):
 974            self.unsupported(f"Only two arguments are supported in function {name}.")
 975
 976        return self.func(name, expression.this, expression.expression)
 977
 978    return _arg_max_or_min_sql
 979
 980
 981def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 982    this = expression.this.copy()
 983
 984    return_type = expression.return_type
 985    if return_type.is_type(exp.DataType.Type.DATE):
 986        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 987        # can truncate timestamp strings, because some dialects can't cast them to DATE
 988        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 989
 990    expression.this.replace(exp.cast(this, return_type))
 991    return expression
 992
 993
 994def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 995    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 996        if cast and isinstance(expression, exp.TsOrDsAdd):
 997            expression = ts_or_ds_add_cast(expression)
 998
 999        return self.func(
1000            name,
1001            exp.var(expression.text("unit").upper() or "DAY"),
1002            expression.expression,
1003            expression.this,
1004        )
1005
1006    return _delta_sql
1007
1008
1009def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1010    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1011    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1012    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1013
1014    return self.sql(exp.cast(minus_one_day, "date"))
1015
1016
1017def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1018    """Remove table refs from columns in when statements."""
1019    alias = expression.this.args.get("alias")
1020
1021    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1022        return self.dialect.normalize_identifier(identifier).name if identifier else None
1023
1024    targets = {normalize(expression.this.this)}
1025
1026    if alias:
1027        targets.add(normalize(alias.this))
1028
1029    for when in expression.expressions:
1030        when.transform(
1031            lambda node: (
1032                exp.column(node.this)
1033                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1034                else node
1035            ),
1036            copy=False,
1037        )
1038
1039    return self.merge_sql(expression)
1040
1041
1042def build_json_extract_path(
1043    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1044) -> t.Callable[[t.List], F]:
1045    def _builder(args: t.List) -> F:
1046        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1047        for arg in args[1:]:
1048            if not isinstance(arg, exp.Literal):
1049                # We use the fallback parser because we can't really transpile non-literals safely
1050                return expr_type.from_arg_list(args)
1051
1052            text = arg.name
1053            if is_int(text):
1054                index = int(text)
1055                segments.append(
1056                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1057                )
1058            else:
1059                segments.append(exp.JSONPathKey(this=text))
1060
1061        # This is done to avoid failing in the expression validator due to the arg count
1062        del args[2:]
1063        return expr_type(
1064            this=seq_get(args, 0),
1065            expression=exp.JSONPath(expressions=segments),
1066            only_json_types=arrow_req_json_type,
1067        )
1068
1069    return _builder
1070
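# Sketch: literal segments become a structured JSONPath, while any non-literal
# argument falls back to the plain function form:
#
#   builder = build_json_extract_path(exp.JSONExtract)
#   node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
#   # node.expression renders as the JSON path $.a[0]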
1071
1072def json_extract_segments(
1073    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1074) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1075    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1076        path = expression.expression
1077        if not isinstance(path, exp.JSONPath):
1078            return rename_func(name)(self, expression)
1079
1080        segments = []
1081        for segment in path.expressions:
1082            path = self.sql(segment)
1083            if path:
1084                if isinstance(segment, exp.JSONPathPart) and (
1085                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1086                ):
1087                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1088
1089                segments.append(path)
1090
1091        if op:
1092            return f" {op} ".join([self.sql(expression.this), *segments])
1093        return self.func(name, expression.this, *segments)
1094
1095    return _json_extract_segments
1096
1097
1098def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1099    if isinstance(expression.this, exp.JSONPathWildcard):
1100        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1101
1102    return expression.name
1103
1104
1105def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1106    cond = expression.expression
1107    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1108        alias = cond.expressions[0]
1109        cond = cond.this
1110    elif isinstance(cond, exp.Predicate):
1111        alias = "_u"
1112    else:
1113        self.unsupported("Unsupported filter condition")
1114        return ""
1115
1116    unnest = exp.Unnest(expressions=[expression.this])
1117    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1118    return self.sql(exp.Array(expressions=[filtered]))
1119
1120
1121def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
1122    return self.func(
1123        "TO_NUMBER",
1124        expression.this,
1125        expression.args.get("format"),
1126        expression.args.get("nlsparam"),
1127    )
logger = <Logger sqlglot (WARNING)>
class Dialects(builtins.str, enum.Enum):
30class Dialects(str, Enum):
31    """Dialects supported by SQLGLot."""
32
33    DIALECT = ""
34
35    ATHENA = "athena"
36    BIGQUERY = "bigquery"
37    CLICKHOUSE = "clickhouse"
38    DATABRICKS = "databricks"
39    DORIS = "doris"
40    DRILL = "drill"
41    DUCKDB = "duckdb"
42    HIVE = "hive"
43    MYSQL = "mysql"
44    ORACLE = "oracle"
45    POSTGRES = "postgres"
46    PRESTO = "presto"
47    PRQL = "prql"
48    REDSHIFT = "redshift"
49    SNOWFLAKE = "snowflake"
50    SPARK = "spark"
51    SPARK2 = "spark2"
52    SQLITE = "sqlite"
53    STARROCKS = "starrocks"
54    TABLEAU = "tableau"
55    TERADATA = "teradata"
56    TRINO = "trino"
57    TSQL = "tsql"

Dialects supported by SQLGLot.

DIALECT = <Dialects.DIALECT: ''>
ATHENA = <Dialects.ATHENA: 'athena'>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
PRQL = <Dialects.PRQL: 'prql'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
60class NormalizationStrategy(str, AutoName):
61    """Specifies the strategy according to which identifiers should be normalized."""
62
63    LOWERCASE = auto()
64    """Unquoted identifiers are lowercased."""
65
66    UPPERCASE = auto()
67    """Unquoted identifiers are uppercased."""
68
69    CASE_SENSITIVE = auto()
70    """Always case-sensitive, regardless of quotes."""
71
72    CASE_INSENSITIVE = auto()
73    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
181class Dialect(metaclass=_Dialect):
182    INDEX_OFFSET = 0
183    """The base index offset for arrays."""
184
185    WEEK_OFFSET = 0
186    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
187
188    UNNEST_COLUMN_ONLY = False
189    """Whether `UNNEST` table aliases are treated as column aliases."""
190
191    ALIAS_POST_TABLESAMPLE = False
192    """Whether the table alias comes after tablesample."""
193
194    TABLESAMPLE_SIZE_IS_PERCENT = False
195    """Whether a size in the table sample clause represents percentage."""
196
197    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
198    """Specifies the strategy according to which identifiers should be normalized."""
199
200    IDENTIFIERS_CAN_START_WITH_DIGIT = False
201    """Whether an unquoted identifier can start with a digit."""
202
203    DPIPE_IS_STRING_CONCAT = True
204    """Whether the DPIPE token (`||`) is a string concatenation operator."""
205
206    STRICT_STRING_CONCAT = False
207    """Whether `CONCAT`'s arguments must be strings."""
208
209    SUPPORTS_USER_DEFINED_TYPES = True
210    """Whether user-defined data types are supported."""
211
212    SUPPORTS_SEMI_ANTI_JOIN = True
213    """Whether `SEMI` or `ANTI` joins are supported."""
214
215    NORMALIZE_FUNCTIONS: bool | str = "upper"
216    """
217    Determines how function names are going to be normalized.
218    Possible values:
219        "upper" or True: Convert names to uppercase.
220        "lower": Convert names to lowercase.
221        False: Disables function name normalization.
222    """
223
224    LOG_BASE_FIRST: t.Optional[bool] = True
225    """
226    Whether the base comes first in the `LOG` function.
227    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
228    """
229
230    NULL_ORDERING = "nulls_are_small"
231    """
232    Default `NULL` ordering method to use if not explicitly set.
233    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
234    """
235
236    TYPED_DIVISION = False
237    """
238    Whether the behavior of `a / b` depends on the types of `a` and `b`.
239    False means `a / b` is always float division.
240    True means `a / b` is integer division if both `a` and `b` are integers.
241    """
242
243    SAFE_DIVISION = False
244    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
245
246    CONCAT_COALESCE = False
247    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
248
249    DATE_FORMAT = "'%Y-%m-%d'"
250    DATEINT_FORMAT = "'%Y%m%d'"
251    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
252
253    TIME_MAPPING: t.Dict[str, str] = {}
254    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
255
256    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
257    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
258    FORMAT_MAPPING: t.Dict[str, str] = {}
259    """
260    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
261    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
262    """
263
264    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
265    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
266
267    PSEUDOCOLUMNS: t.Set[str] = set()
268    """
269    Columns that are auto-generated by the engine corresponding to this dialect.
270    For example, such columns may be excluded from `SELECT *` queries.
271    """
272
273    PREFER_CTE_ALIAS_COLUMN = False
274    """
275    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
276    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
277    any projection aliases in the subquery.
278
279    For example,
280        WITH y(c) AS (
281            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
282        ) SELECT c FROM y;
283
284        will be rewritten as
285
286        WITH y(c) AS (
287            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
288        ) SELECT c FROM y;
289    """
290
291    # --- Autofilled ---
292
293    tokenizer_class = Tokenizer
294    parser_class = Parser
295    generator_class = Generator
296
297    # A trie of the time_mapping keys
298    TIME_TRIE: t.Dict = {}
299    FORMAT_TRIE: t.Dict = {}
300
301    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
302    INVERSE_TIME_TRIE: t.Dict = {}
303
304    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
305
306    # Delimiters for string literals and identifiers
307    QUOTE_START = "'"
308    QUOTE_END = "'"
309    IDENTIFIER_START = '"'
310    IDENTIFIER_END = '"'
311
312    # Delimiters for bit, hex, byte and unicode literals
313    BIT_START: t.Optional[str] = None
314    BIT_END: t.Optional[str] = None
315    HEX_START: t.Optional[str] = None
316    HEX_END: t.Optional[str] = None
317    BYTE_START: t.Optional[str] = None
318    BYTE_END: t.Optional[str] = None
319    UNICODE_START: t.Optional[str] = None
320    UNICODE_END: t.Optional[str] = None
321
322    @classmethod
323    def get_or_raise(cls, dialect: DialectType) -> Dialect:
324        """
325        Look up a dialect in the global dialect registry and return it if it exists.
326
327        Args:
328            dialect: The target dialect. If this is a string, it can be optionally followed by
329                additional key-value pairs that are separated by commas and are used to specify
330                dialect settings, such as whether the dialect's identifiers are case-sensitive.
331
332        Example:
333            >>> dialect = dialect_class = get_or_raise("duckdb")
334            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
335
336        Returns:
337            The corresponding Dialect instance.
338        """
339
340        if not dialect:
341            return cls()
342        if isinstance(dialect, _Dialect):
343            return dialect()
344        if isinstance(dialect, Dialect):
345            return dialect
346        if isinstance(dialect, str):
347            try:
348                dialect_name, *kv_pairs = dialect.split(",")
349                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
350            except ValueError:
351                raise ValueError(
352                    f"Invalid dialect format: '{dialect}'. "
353                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
354                )
355
356            result = cls.get(dialect_name.strip())
357            if not result:
358                from difflib import get_close_matches
359
360                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
361                if similar:
362                    similar = f" Did you mean {similar}?"
363
364                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
365
366            return result(**kwargs)
367
368        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
369
370    @classmethod
371    def format_time(
372        cls, expression: t.Optional[str | exp.Expression]
373    ) -> t.Optional[exp.Expression]:
374        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
375        if isinstance(expression, str):
376            return exp.Literal.string(
377                # the time formats are quoted
378                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
379            )
380
381        if expression and expression.is_string:
382            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
383
384        return expression
385
386    def __init__(self, **kwargs) -> None:
387        normalization_strategy = kwargs.get("normalization_strategy")
388
389        if normalization_strategy is None:
390            self.normalization_strategy = self.NORMALIZATION_STRATEGY
391        else:
392            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
393
394    def __eq__(self, other: t.Any) -> bool:
395        # Does not currently take dialect state into account
396        return type(self) == other
397
398    def __hash__(self) -> int:
399        # Does not currently take dialect state into account
400        return hash(type(self))
401
402    def normalize_identifier(self, expression: E) -> E:
403        """
404        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
405
406        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
407        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
408        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
409        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
410
411        There are also dialects like Spark, which are case-insensitive even when quotes are
412        present, and dialects like MySQL, whose resolution rules match those employed by the
413        underlying operating system; for example, they may always be case-sensitive on Linux.
414
415        Finally, the normalization behavior of some engines can even be controlled through flags,
416        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
417
418        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
419        that it can analyze queries in the optimizer and successfully capture their semantics.
420        """
421        if (
422            isinstance(expression, exp.Identifier)
423            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
424            and (
425                not expression.quoted
426                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
427            )
428        ):
429            expression.set(
430                "this",
431                (
432                    expression.this.upper()
433                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
434                    else expression.this.lower()
435                ),
436            )
437
438        return expression
439
440    def case_sensitive(self, text: str) -> bool:
441        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
442        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
443            return False
444
445        unsafe = (
446            str.islower
447            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
448            else str.isupper
449        )
450        return any(unsafe(char) for char in text)
451
452    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
453        """Checks if text can be identified given an identify option.
454
455        Args:
456            text: The text to check.
457            identify:
458                `"always"` or `True`: Always returns `True`.
459                `"safe"`: Only returns `True` if the identifier is case-insensitive.
460
461        Returns:
462            Whether the given text can be identified.
463        """
464        if identify is True or identify == "always":
465            return True
466
467        if identify == "safe":
468            return not self.case_sensitive(text)
469
470        return False
471
472    def quote_identifier(self, expression: E, identify: bool = True) -> E:
473        """
474        Adds quotes to a given identifier.
475
476        Args:
477            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
478            identify: If set to `False`, the quotes will only be added if the identifier is deemed
479                "unsafe", with respect to its characters and this dialect's normalization strategy.
480        """
481        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
482            name = expression.this
483            expression.set(
484                "quoted",
485                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
486            )
487
488        return expression
489
490    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
491        if isinstance(path, exp.Literal):
492            path_text = path.name
493            if path.is_number:
494                path_text = f"[{path_text}]"
495
496            try:
497                return parse_json_path(path_text)
498            except ParseError as e:
499                logger.warning(f"Invalid JSON path syntax. {str(e)}")
500
501        return path
502
503    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
504        return self.parser(**opts).parse(self.tokenize(sql), sql)
505
506    def parse_into(
507        self, expression_type: exp.IntoType, sql: str, **opts
508    ) -> t.List[t.Optional[exp.Expression]]:
509        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
510
511    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
512        return self.generator(**opts).generate(expression, copy=copy)
513
514    def transpile(self, sql: str, **opts) -> t.List[str]:
515        return [
516            self.generate(expression, copy=False, **opts) if expression else ""
517            for expression in self.parse(sql)
518        ]
519
520    def tokenize(self, sql: str) -> t.List[Token]:
521        return self.tokenizer.tokenize(sql)
522
523    @property
524    def tokenizer(self) -> Tokenizer:
525        if not hasattr(self, "_tokenizer"):
526            self._tokenizer = self.tokenizer_class(dialect=self)
527        return self._tokenizer
528
529    def parser(self, **opts) -> Parser:
530        return self.parser_class(dialect=self, **opts)
531
532    def generator(self, **opts) -> Generator:
533        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

  • "upper" or True: Convert names to uppercase.
  • "lower": Convert names to lowercase.
  • False: Disables function name normalization.

LOG_BASE_FIRST: Optional[bool] = True

Whether the base comes first in the LOG function. Possible values: True, False, None (two arguments are not supported by LOG)

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

UNESCAPED_SEQUENCES: Dict[str, str] = {}

Mapping of an escaped sequence (e.g. the two characters \n) to its unescaped version (e.g. a literal newline).

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = sqlglot.tokens.Tokenizer
parser_class = sqlglot.parser.Parser
generator_class = sqlglot.generator.Generator
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
ESCAPED_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
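
A minimal usage sketch (assuming the bundled dialects have been registered by importing sqlglot.dialects; the "postgers" lookup is a deliberate typo to show the error path):

import sqlglot.dialects  # importing the package registers all bundled dialects
from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

duckdb = Dialect.get_or_raise("duckdb")  # -> a DuckDB Dialect instance
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE

try:
    Dialect.get_or_raise("postgers")
except ValueError as e:
    print(e)  # Unknown dialect 'postgers'. Did you mean postgres?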

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:

Converts a time format in this dialect to its equivalent Python strftime format.
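
A short sketch, assuming Hive's TIME_MAPPING translates 'yyyy', 'MM' and 'dd' to their strftime equivalents; note that the string form of the format is passed quoted, exactly as it appears in SQL:

import sqlglot.dialects  # registers the bundled dialects
from sqlglot.dialects.dialect import Dialect

fmt = Dialect["hive"].format_time("'yyyy-MM-dd'")
print(fmt.sql())  # '%Y-%m-%d'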

def normalize_identifier(self, expression: ~E) -> ~E:

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
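
An illustrative sketch of the behaviors described above, using the public expression helpers:

import sqlglot.dialects
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

unquoted = exp.to_identifier("FoO")
print(Dialect.get_or_raise("postgres").normalize_identifier(unquoted.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(unquoted.copy()).name)  # FOO

quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted.copy()).name)     # FoO (left as-is)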

def case_sensitive(self, text: str) -> bool:

Checks if text contains any case sensitive characters, based on the dialect's rules.

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
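
A small sketch covering case_sensitive and can_identify under Snowflake's UPPERCASE strategy, where lowercase characters are the "unsafe" ones:

import sqlglot.dialects
from sqlglot.dialects.dialect import Dialect

snow = Dialect.get_or_raise("snowflake")
print(snow.case_sensitive("FOO"))          # False
print(snow.case_sensitive("Foo"))          # True
print(snow.can_identify("FOO"))            # True (default identify="safe")
print(snow.can_identify("Foo"))            # False
print(snow.can_identify("Foo", "always"))  # True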

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
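
A sketch with Postgres (LOWERCASE strategy): with identify=False, only identifiers that would otherwise change meaning get quoted:

import sqlglot.dialects
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

pg = Dialect.get_or_raise("postgres")
print(pg.quote_identifier(exp.to_identifier("foo"), identify=False).sql(dialect="postgres"))  # foo
print(pg.quote_identifier(exp.to_identifier("FoO"), identify=False).sql(dialect="postgres"))  # "FoO"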
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
def transpile(self, sql: str, **opts) -> List[str]:
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
tokenizer: sqlglot.tokens.Tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
def generator(self, **opts) -> sqlglot.generator.Generator:
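
An end-to-end sketch of the instance-level helpers above:

import sqlglot.dialects
from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("duckdb")
tokens = d.tokenize("SELECT 1 + 1 AS x")  # the Tokenizer instance is cached on the dialect
ast = d.parse("SELECT 1 + 1 AS x")[0]     # a fresh Parser/Generator is created per call
print(d.generate(ast))                    # SELECT 1 + 1 AS x
print(d.transpile("SELECT 1 + 1 AS x"))   # ['SELECT 1 + 1 AS x']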
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
539def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
540    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
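
A hedged sketch of how rename_func is typically installed in a dialect generator's TRANSFORMS table; MyDialect and APPROX_SET are made-up names for illustration:

import sqlglot
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, rename_func
from sqlglot.generator import Generator as BaseGenerator

class MyDialect(Dialect):
    class Generator(BaseGenerator):
        TRANSFORMS = {
            **BaseGenerator.TRANSFORMS,
            exp.ApproxDistinct: rename_func("APPROX_SET"),
        }

print(sqlglot.transpile("SELECT APPROX_DISTINCT(x)", write="mydialect")[0])
# SELECT APPROX_SET(x)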
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
543def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
544    if expression.args.get("accuracy"):
545        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
546    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
549def if_sql(
550    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
551) -> t.Callable[[Generator, exp.If], str]:
552    def _if_sql(self: Generator, expression: exp.If) -> str:
553        return self.func(
554            name,
555            expression.this,
556            expression.args.get("true"),
557            expression.args.get("false") or false_value,
558        )
559
560    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
563def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
564    this = expression.this
565    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
566        this.replace(exp.cast(this, "json"))
567
568    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
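
A sketch of the operator choice, assuming a target dialect that maps JSONExtract to this helper (DuckDB does, at the time of writing); JSONExtract renders as -> and JSONExtractScalar as ->>:

import sqlglot.dialects
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.jsonpath import parse as parse_json_path

node = exp.JSONExtract(this=exp.column("x"), expression=parse_json_path("$.a"))
print(Dialect.get_or_raise("duckdb").generate(node))  # x -> '$.a'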
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
571def inline_array_sql(self: Generator, expression: exp.Array) -> str:
572    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
575def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
576    return self.like_sql(
577        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
578    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
581def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
582    zone = self.sql(expression, "this")
583    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
586def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
587    if expression.args.get("recursive"):
588        self.unsupported("Recursive CTEs are unsupported")
589        expression.args["recursive"] = False
590    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
593def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
594    n = self.sql(expression, "this")
595    d = self.sql(expression, "expression")
596    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
599def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
600    self.unsupported("TABLESAMPLE unsupported")
601    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
604def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
605    self.unsupported("PIVOT unsupported")
606    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
609def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
610    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
613def no_comment_column_constraint_sql(
614    self: Generator, expression: exp.CommentColumnConstraint
615) -> str:
616    self.unsupported("CommentColumnConstraint unsupported")
617    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
620def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
621    self.unsupported("MAP_FROM_ENTRIES unsupported")
622    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition, generate_instance: bool = False) -> str:
625def str_position_sql(
626    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
627) -> str:
628    this = self.sql(expression, "this")
629    substr = self.sql(expression, "substr")
630    position = self.sql(expression, "position")
631    instance = expression.args.get("instance") if generate_instance else None
632    position_offset = ""
633
634    if position:
635        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
636        this = self.func("SUBSTR", this, position)
637        position_offset = f" + {position} - 1"
638
639    return self.func("STRPOS", this, substr, instance) + position_offset
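
A transpilation sketch, assuming a source dialect with LOCATE(substr, str, pos) and a target that maps StrPosition to this helper (Postgres does, at the time of writing):

import sqlglot

print(sqlglot.transpile("SELECT LOCATE('a', x, 3)", read="mysql", write="postgres")[0])
# SELECT STRPOS(SUBSTR(x, 3), 'a') + 3 - 1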
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
642def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
643    return (
644        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
645    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
648def var_map_sql(
649    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
650) -> str:
651    keys = expression.args["keys"]
652    values = expression.args["values"]
653
654    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
655        self.unsupported("Cannot convert array columns into map.")
656        return self.func(map_func_name, keys, values)
657
658    args = []
659    for key, value in zip(keys.expressions, values.expressions):
660        args.append(self.sql(key))
661        args.append(self.sql(value))
662
663    return self.func(map_func_name, *args)
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
666def build_formatted_time(
667    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
668) -> t.Callable[[t.List], E]:
669    """Helper used for time expressions.
670
671    Args:
672        exp_class: the expression class to instantiate.
673        dialect: target sql dialect.
674        default: the default format, True being time.
675
676    Returns:
677        A callable that can be used to return the appropriately formatted time expression.
678    """
679
680    def _builder(args: t.List):
681        return exp_class(
682            this=seq_get(args, 0),
683            format=Dialect[dialect].format_time(
684                seq_get(args, 1)
685                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
686            ),
687        )
688
689    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
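
A sketch of the builder in isolation; dialects normally register the returned callable in their parser's FUNCTIONS table:

import sqlglot.dialects  # registers the bundled dialects
from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "hive")
node = builder([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
print(node.sql())  # STR_TO_TIME(x, '%Y-%m-%d')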

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
692def time_format(
693    dialect: DialectType = None,
694) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
695    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
696        """
697        Returns the time format for a given expression, unless it's equivalent
698        to the default time format of the dialect of interest.
699        """
700        time_format = self.format_time(expression)
701        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
702
703    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
706def build_date_delta(
707    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
708) -> t.Callable[[t.List], E]:
709    def _builder(args: t.List) -> E:
710        unit_based = len(args) == 3
711        this = args[2] if unit_based else seq_get(args, 0)
712        unit = args[0] if unit_based else exp.Literal.string("DAY")
713        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
714        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
715
716    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
719def build_date_delta_with_interval(
720    expression_class: t.Type[E],
721) -> t.Callable[[t.List], t.Optional[E]]:
722    def _builder(args: t.List) -> t.Optional[E]:
723        if len(args) < 2:
724            return None
725
726        interval = args[1]
727
728        if not isinstance(interval, exp.Interval):
729            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
730
731        expression = interval.this
732        if expression and expression.is_string:
733            expression = exp.Literal.number(expression.this)
734
735        return expression_class(
736            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
737        )
738
739    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
742def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
743    unit = seq_get(args, 0)
744    this = seq_get(args, 1)
745
746    if isinstance(this, exp.Cast) and this.is_type("date"):
747        return exp.DateTrunc(unit=unit, this=this)
748    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
751def date_add_interval_sql(
752    data_type: str, kind: str
753) -> t.Callable[[Generator, exp.Expression], str]:
754    def func(self: Generator, expression: exp.Expression) -> str:
755        this = self.sql(expression, "this")
756        unit = expression.args.get("unit")
757        unit = exp.var(unit.name.upper() if unit else "DAY")
758        interval = exp.Interval(this=expression.expression, unit=unit)
759        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
760
761    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
764def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
765    return self.func(
766        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
767    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
770def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
771    if not expression.expression:
772        from sqlglot.optimizer.annotate_types import annotate_types
773
774        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
775        return self.sql(exp.cast(expression.this, to=target_type))
776    if expression.text("expression").lower() in TIMEZONES:
777        return self.sql(
778            exp.AtTimeZone(
779                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
780                zone=expression.expression,
781            )
782        )
783    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
786def locate_to_strposition(args: t.List) -> exp.Expression:
787    return exp.StrPosition(
788        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
789    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
792def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
793    return self.func(
794        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
795    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
798def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
799    return self.sql(
800        exp.Substring(
801            this=expression.this, start=exp.Literal.number(1), length=expression.expression
802        )
803    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
806def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
807    return self.sql(
808        exp.Substring(
809            this=expression.this,
810            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
811        )
812    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
815def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
816    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
819def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
820    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
824def encode_decode_sql(
825    self: Generator, expression: exp.Expression, name: str, replace: bool = True
826) -> str:
827    charset = expression.args.get("charset")
828    if charset and charset.name.lower() != "utf-8":
829        self.unsupported(f"Expected utf-8 character set, got {charset}.")
830
831    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
834def min_or_least(self: Generator, expression: exp.Min) -> str:
835    name = "LEAST" if expression.expressions else "MIN"
836    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
839def max_or_greatest(self: Generator, expression: exp.Max) -> str:
840    name = "GREATEST" if expression.expressions else "MAX"
841    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
844def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
845    cond = expression.this
846
847    if isinstance(expression.this, exp.Distinct):
848        cond = expression.this.expressions[0]
849        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
850
851    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
854def trim_sql(self: Generator, expression: exp.Trim) -> str:
855    target = self.sql(expression, "this")
856    trim_type = self.sql(expression, "position")
857    remove_chars = self.sql(expression, "expression")
858    collation = self.sql(expression, "collation")
859
860    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
861    if not remove_chars and not collation:
862        return self.trim_sql(expression)
863
864    trim_type = f"{trim_type} " if trim_type else ""
865    remove_chars = f"{remove_chars} " if remove_chars else ""
866    from_part = "FROM " if trim_type or remove_chars else ""
867    collation = f" COLLATE {collation}" if collation else ""
868    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
871def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
872    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
875def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
876    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
879def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
880    delim, *rest_args = expression.expressions
881    return self.sql(
882        reduce(
883            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
884            rest_args,
885        )
886    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
889def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
890    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
891    if bad_args:
892        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
893
894    return self.func(
895        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
896    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
899def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
900    bad_args = list(
901        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
902    )
903    if bad_args:
904        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
905
906    return self.func(
907        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
908    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
911def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
912    names = []
913    for agg in aggregations:
914        if isinstance(agg, exp.Alias):
915            names.append(agg.alias)
916        else:
917            """
918            This case corresponds to aggregations without aliases being used as suffixes
919            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
920            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
921            Otherwise, we'd end up with `col_avg(`foo`)` (notice the quoted identifier).
922            """
923            agg_all_unquoted = agg.transform(
924                lambda node: (
925                    exp.Identifier(this=node.name, quoted=False)
926                    if isinstance(node, exp.Identifier)
927                    else node
928                )
929            )
930            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
931
932    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
935def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
936    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
940def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
941    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
944def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
945    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
948def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
949    a = self.sql(expression.left)
950    b = self.sql(expression.right)
951    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
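
A sketch of the expansion, assuming a target dialect that maps exp.Xor to this helper (Postgres does, at the time of writing):

import sqlglot.dialects
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

xor = exp.Xor(this=exp.column("a"), expression=exp.column("b"))
print(Dialect.get_or_raise("postgres").generate(xor))
# (a AND (NOT b)) OR ((NOT a) AND b)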
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
954def is_parse_json(expression: exp.Expression) -> bool:
955    return isinstance(expression, exp.ParseJSON) or (
956        isinstance(expression, exp.Cast) and expression.is_type("json")
957    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
960def isnull_to_is_null(args: t.List) -> exp.Expression:
961    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
964def generatedasidentitycolumnconstraint_sql(
965    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
966) -> str:
967    start = self.sql(expression, "start") or "1"
968    increment = self.sql(expression, "increment") or "1"
969    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
972def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
973    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
974        if expression.args.get("count"):
975            self.unsupported(f"Only two arguments are supported in function {name}.")
976
977        return self.func(name, expression.this, expression.expression)
978
979    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
982def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
983    this = expression.this.copy()
984
985    return_type = expression.return_type
986    if return_type.is_type(exp.DataType.Type.DATE):
987        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
988        # can truncate timestamp strings, because some dialects can't cast them to DATE
989        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
990
991    expression.this.replace(exp.cast(this, return_type))
992    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
 995def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 996    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 997        if cast and isinstance(expression, exp.TsOrDsAdd):
 998            expression = ts_or_ds_add_cast(expression)
 999
1000        return self.func(
1001            name,
1002            exp.var(expression.text("unit").upper() or "DAY"),
1003            expression.expression,
1004            expression.this,
1005        )
1006
1007    return _delta_sql
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
1010def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
1011    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
1012    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
1013    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
1014
1015    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1018def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1019    """Remove table refs from columns in when statements."""
1020    alias = expression.this.args.get("alias")
1021
1022    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
1023        return self.dialect.normalize_identifier(identifier).name if identifier else None
1024
1025    targets = {normalize(expression.this.this)}
1026
1027    if alias:
1028        targets.add(normalize(alias.this))
1029
1030    for when in expression.expressions:
1031        when.transform(
1032            lambda node: (
1033                exp.column(node.this)
1034                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1035                else node
1036            ),
1037            copy=False,
1038        )
1039
1040    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False) -> Callable[[List], ~F]:
1043def build_json_extract_path(
1044    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
1045) -> t.Callable[[t.List], F]:
1046    def _builder(args: t.List) -> F:
1047        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1048        for arg in args[1:]:
1049            if not isinstance(arg, exp.Literal):
1050                # We use the fallback parser because we can't really transpile non-literals safely
1051                return expr_type.from_arg_list(args)
1052
1053            text = arg.name
1054            if is_int(text):
1055                index = int(text)
1056                segments.append(
1057                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1058                )
1059            else:
1060                segments.append(exp.JSONPathKey(this=text))
1061
1062        # This is done to avoid failing in the expression validator due to the arg count
1063        del args[2:]
1064        return expr_type(
1065            this=seq_get(args, 0),
1066            expression=exp.JSONPath(expressions=segments),
1067            only_json_types=arrow_req_json_type,
1068        )
1069
1070    return _builder
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1073def json_extract_segments(
1074    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1075) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1076    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1077        path = expression.expression
1078        if not isinstance(path, exp.JSONPath):
1079            return rename_func(name)(self, expression)
1080
1081        segments = []
1082        for segment in path.expressions:
1083            path = self.sql(segment)
1084            if path:
1085                if isinstance(segment, exp.JSONPathPart) and (
1086                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1087                ):
1088                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1089
1090                segments.append(path)
1091
1092        if op:
1093            return f" {op} ".join([self.sql(expression.this), *segments])
1094        return self.func(name, expression.this, *segments)
1095
1096    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1099def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1100    if isinstance(expression.this, exp.JSONPathWildcard):
1101        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1102
1103    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1106def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1107    cond = expression.expression
1108    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1109        alias = cond.expressions[0]
1110        cond = cond.this
1111    elif isinstance(cond, exp.Predicate):
1112        alias = "_u"
1113    else:
1114        self.unsupported("Unsupported filter condition")
1115        return ""
1116
1117    unnest = exp.Unnest(expressions=[expression.this])
1118    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1119    return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param(self, expression: sqlglot.expressions.ToNumber) -> str:
1122def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
1123    return self.func(
1124        "TO_NUMBER",
1125        expression.this,
1126        expression.args.get("format"),
1127        expression.args.get("nlsparam"),
1128    )