sqlglot.dialects.dialect

from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass

class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.
    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST = True
    """Whether the base comes first in the `LOG` function."""

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an unescaped escape sequence to the corresponding character."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

        will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

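    # Illustrative note (not part of the module source): the comma-separated
    # settings are forwarded to the dialect's constructor as keyword arguments,
    # so a normalization strategy can be attached at lookup time:
    #
    #     >>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    #     >>> d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
    #     True
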
    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system; for example, they may always be case-sensitive on Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set `enable_case_sensitive_identifier`.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

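    # Illustrative sketch (not part of the module source): Postgres lowercases
    # unquoted identifiers while Snowflake uppercases them, so the same input
    # normalizes differently per dialect:
    #
    #     >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
    #     'foo'
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
    #     'FOO'
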
    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

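    # Illustrative sketch (not part of the module source): with identify=False,
    # quotes are only added when the name is "unsafe", e.g. because its casing
    # would otherwise be normalized away:
    #
    #     >>> pg = Dialect.get_or_raise("postgres")
    #     >>> pg.quote_identifier(exp.to_identifier("FoO"), identify=False).sql()
    #     '"FoO"'
    #     >>> pg.quote_identifier(exp.to_identifier("foo"), identify=False).sql()
    #     'foo'
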
    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]

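# Illustrative sketch (not part of the module source): a Dialect instance wires
# its tokenizer, parser and generator together, so a round trip is just
# tokenize -> parse -> generate:
#
#     >>> d = Dialect.get_or_raise("duckdb")
#     >>> ast = d.parse("SELECT 1")[0]
#     >>> d.generate(ast)
#     'SELECT 1'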

def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))

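# Illustrative sketch (not part of the module source): dialect generators
# typically register helpers like rename_func in their TRANSFORMS mapping; the
# class and mapping below are examples, not actual sqlglot definitions:
#
#     class Generator(generator.Generator):
#         TRANSFORMS = {
#             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
#         }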

def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql

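# Illustrative note (not part of the module source): Snowflake-style dialects
# can render exp.If as IFF with a NULL fallback via this factory, e.g.:
#
#     TRANSFORMS = {
#         exp.If: if_sql(name="IFF", false_value="NULL"),
#     }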

def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder

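# Illustrative sketch (not part of the module source): parsers register this
# builder for functions whose second argument is a format string, so the format
# is converted into Python strftime notation at parse time, e.g.:
#
#     FUNCTIONS = {
#         "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "hive"),
#     }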

def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder

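# Illustrative note (not part of the module source): the builder accepts both
# the two-argument form DATE_ADD(date, days), where the unit defaults to DAY,
# and the unit-based three-argument form DATEADD(unit, value, date), e.g.:
#
#     FUNCTIONS = {
#         "DATEADD": build_date_delta(exp.DateAdd),
#     }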

def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, to=target_type))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )

def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and DuckDB, whose encode/decode functions don't support a charset and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"

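# Illustrative note (not part of the module source): a dialect whose TRIM
# supports the expanded ANSI syntax can register this helper directly; a plain
# TRIM(x) then falls back to the generator's built-in TRIM/LTRIM/RTRIM form:
#
#     TRANSFORMS = {
#         exp.Trim: trim_sql,
#     }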

def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )

def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (note the backticks).
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names



def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            exp.var(expression.text("unit").upper() or "DAY"),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, "date"))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))

    return _builder

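# Illustrative sketch (not part of the module source): registering the builder
# lets a parser map path-style extraction functions onto exp.JSONExtract with a
# structured exp.JSONPath, e.g.:
#
#     FUNCTIONS = {
#         "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
#     }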

def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
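

# Illustrative sketch (not part of the module source): the helpers above are
# exercised whenever SQL is transpiled between dialects, e.g. via the top-level
# API (output shown for illustration):
#
#     >>> import sqlglot
#     >>> sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive")[0]
#     'SELECT FROM_UNIXTIME(1618088028295 / 1000)'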
489        return [
490            self.generate(expression, copy=False, **opts) if expression else ""
491            for expression in self.parse(sql)
492        ]
493
494    def tokenize(self, sql: str) -> t.List[Token]:
495        return self.tokenizer.tokenize(sql)
496
497    @property
498    def tokenizer(self) -> Tokenizer:
499        if not hasattr(self, "_tokenizer"):
500            self._tokenizer = self.tokenizer_class(dialect=self)
501        return self._tokenizer
502
503    def parser(self, **opts) -> Parser:
504        return self.parser_class(dialect=self, **opts)
505
506    def generator(self, **opts) -> Generator:
507        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
360    def __init__(self, **kwargs) -> None:
361        normalization_strategy = kwargs.get("normalization_strategy")
362
363        if normalization_strategy is None:
364            self.normalization_strategy = self.NORMALIZATION_STRATEGY
365        else:
366            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0

The base index offset for arrays.

WEEK_OFFSET = 0

First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Whether UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Whether the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Whether a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Whether an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Whether the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Whether CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Whether user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Whether SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

Possible values:

"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.

LOG_BASE_FIRST = True

Whether the base comes first in the LOG function.

NULL_ORDERING = 'nulls_are_small'

Default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime formats.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed off of TIME_MAPPING.

ESCAPE_SEQUENCES: Dict[str, str] = {}

Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
296    @classmethod
297    def get_or_raise(cls, dialect: DialectType) -> Dialect:
298        """
299        Look up a dialect in the global dialect registry and return it if it exists.
300
301        Args:
302            dialect: The target dialect. If this is a string, it can be optionally followed by
303                additional key-value pairs that are separated by commas and are used to specify
304                dialect settings, such as whether the dialect's identifiers are case-sensitive.
305
306        Example:
307            >>> dialect = Dialect.get_or_raise("duckdb")
308            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
309
310        Returns:
311            The corresponding Dialect instance.
312        """
313
314        if not dialect:
315            return cls()
316        if isinstance(dialect, _Dialect):
317            return dialect()
318        if isinstance(dialect, Dialect):
319            return dialect
320        if isinstance(dialect, str):
321            try:
322                dialect_name, *kv_pairs = dialect.split(",")
323                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
324            except ValueError:
325                raise ValueError(
326                    f"Invalid dialect format: '{dialect}'. "
327                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
328                )
329
330            result = cls.get(dialect_name.strip())
331            if not result:
332                from difflib import get_close_matches
333
334                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
335                if similar:
336                    similar = f" Did you mean {similar}?"
337
338                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
339
340            return result(**kwargs)
341
342        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.
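
A minimal sketch of both lookup forms (inline settings are forwarded to the dialect's constructor):

from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

# Plain lookup: returns a fresh instance with default settings.
duckdb = Dialect.get_or_raise("duckdb")

# Inline settings: everything after the first comma becomes a constructor kwarg.
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
assert mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE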

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
344    @classmethod
345    def format_time(
346        cls, expression: t.Optional[str | exp.Expression]
347    ) -> t.Optional[exp.Expression]:
348        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
349        if isinstance(expression, str):
350            return exp.Literal.string(
351                # the time formats are quoted
352                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
353            )
354
355        if expression and expression.is_string:
356            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
357
358        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
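
For instance, assuming a dialect whose TIME_MAPPING covers the relevant tokens (Hive is used here as an illustration):

from sqlglot.dialects.dialect import Dialect

# A raw string argument is assumed to be quoted, so the surrounding quotes are
# stripped before the dialect's time tokens are mapped to strftime ones.
fmt = Dialect["hive"].format_time("'yyyy-MM-dd'")
print(fmt.sql())  # '%Y-%m-%d'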

def normalize_identifier(self, expression: ~E) -> ~E:
376    def normalize_identifier(self, expression: E) -> E:
377        """
378        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
379
380        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
381        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
382        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
383        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
384
385        There are also dialects like Spark, which are case-insensitive even when quotes are
386        present, and dialects like MySQL, whose resolution rules match those employed by the
387        underlying operating system, for example they may always be case-sensitive in Linux.
388
389        Finally, the normalization behavior of some engines can even be controlled through flags,
390        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
391
392        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
393        that it can analyze queries in the optimizer and successfully capture their semantics.
394        """
395        if (
396            isinstance(expression, exp.Identifier)
397            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
398            and (
399                not expression.quoted
400                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
401            )
402        ):
403            expression.set(
404                "this",
405                (
406                    expression.this.upper()
407                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
408                    else expression.this.lower()
409                ),
410            )
411
412        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those of the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
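
A quick sketch of the behaviors described above:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted

print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

# Quoted identifiers are treated as case-sensitive and left untouched.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)  # FoO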

def case_sensitive(self, text: str) -> bool:
414    def case_sensitive(self, text: str) -> bool:
415        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
416        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
417            return False
418
419        unsafe = (
420            str.islower
421            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
422            else str.isupper
423        )
424        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.
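
Roughly, a character is "unsafe" when normalization would change it. For example, under the default (lowercase) strategy and Snowflake's uppercase one:

from sqlglot.dialects.dialect import Dialect

d = Dialect()  # default strategy lowercases unquoted identifiers
print(d.case_sensitive("foo"))  # False
print(d.case_sensitive("Foo"))  # True: the uppercase character would be lowercased

sf = Dialect.get_or_raise("snowflake")  # uppercases unquoted identifiers
print(sf.case_sensitive("FOO"))  # False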

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
426    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
427        """Checks if text can be identified given an identify option.
428
429        Args:
430            text: The text to check.
431            identify:
432                `"always"` or `True`: Always returns `True`.
433                `"safe"`: Only returns `True` if the identifier is case-insensitive.
434
435        Returns:
436            Whether the given text can be identified.
437        """
438        if identify is True or identify == "always":
439            return True
440
441        if identify == "safe":
442            return not self.case_sensitive(text)
443
444        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether the given text can be identified.
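
For example, under the default (lowercase) strategy:

from sqlglot.dialects.dialect import Dialect

d = Dialect()
print(d.can_identify("foo", identify="safe"))    # True: quoting can't change resolution
print(d.can_identify("Foo", identify="safe"))    # False: "Foo" is case-sensitive here
print(d.can_identify("Foo", identify="always"))  # True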

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
446    def quote_identifier(self, expression: E, identify: bool = True) -> E:
447        """
448        Adds quotes to a given identifier.
449
450        Args:
451            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
452            identify: If set to `False`, the quotes will only be added if the identifier is deemed
453                "unsafe", with respect to its characters and this dialect's normalization strategy.
454        """
455        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
456            name = expression.this
457            expression.set(
458                "quoted",
459                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
460            )
461
462        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
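
A minimal sketch using the default dialect:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

d = Dialect()
print(d.quote_identifier(exp.to_identifier("foo")).sql())                     # "foo"
print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())     # foo
print(d.quote_identifier(exp.to_identifier("my col"), identify=False).sql())  # "my col"
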
def to_json_path( self, path: Optional[sqlglot.expressions.Expression]) -> Optional[sqlglot.expressions.Expression]:
464    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
465        if isinstance(path, exp.Literal):
466            path_text = path.name
467            if path.is_number:
468                path_text = f"[{path_text}]"
469
470            try:
471                return parse_json_path(path_text)
472            except ParseError as e:
473                logger.warning(f"Invalid JSON path syntax. {str(e)}")
474
475        return path
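
String literals are parsed into structured exp.JSONPath nodes, numeric literals are treated as subscripts, and anything unparseable is returned as-is with a warning. A short sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

path = Dialect().to_json_path(exp.Literal.string("$.a[0]"))
print(type(path).__name__, path.sql())  # JSONPath $.a[0]
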
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
477    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
478        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
480    def parse_into(
481        self, expression_type: exp.IntoType, sql: str, **opts
482    ) -> t.List[t.Optional[exp.Expression]]:
483        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
485    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
486        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
488    def transpile(self, sql: str, **opts) -> t.List[str]:
489        return [
490            self.generate(expression, copy=False, **opts) if expression else ""
491            for expression in self.parse(sql)
492        ]
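
The parse/generate/transpile methods compose as expected; a small round trip:

from sqlglot.dialects.dialect import Dialect

d = Dialect.get_or_raise("duckdb")
[tree] = d.parse("SELECT 1 AS x")
print(d.generate(tree))              # SELECT 1 AS x
print(d.transpile("SELECT 1 AS x"))  # ['SELECT 1 AS x']
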
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
494    def tokenize(self, sql: str) -> t.List[Token]:
495        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
497    @property
498    def tokenizer(self) -> Tokenizer:
499        if not hasattr(self, "_tokenizer"):
500            self._tokenizer = self.tokenizer_class(dialect=self)
501        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
503    def parser(self, **opts) -> Parser:
504        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
506    def generator(self, **opts) -> Generator:
507        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
513def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
514    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
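
The returned callable plugs into a generator's TRANSFORMS map. A minimal sketch with a hypothetical dialect and target name (Toy and APPROX_SET are illustrative, not part of sqlglot):

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import Dialect, rename_func
from sqlglot.generator import Generator as BaseGenerator

class Toy(Dialect):  # hypothetical dialect, registered as "toy"
    class Generator(BaseGenerator):
        TRANSFORMS = {
            **BaseGenerator.TRANSFORMS,
            exp.ApproxDistinct: rename_func("APPROX_SET"),
        }

print(Toy().generate(parse_one("SELECT APPROX_DISTINCT(x)")))
# SELECT APPROX_SET(x)
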
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
517def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
518    if expression.args.get("accuracy"):
519        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
520    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
523def if_sql(
524    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
525) -> t.Callable[[Generator, exp.If], str]:
526    def _if_sql(self: Generator, expression: exp.If) -> str:
527        return self.func(
528            name,
529            expression.this,
530            expression.args.get("true"),
531            expression.args.get("false") or false_value,
532        )
533
534    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]) -> str:
537def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
538    this = expression.this
539    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
540        this.replace(exp.cast(this, "json"))
541
542    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
545def inline_array_sql(self: Generator, expression: exp.Array) -> str:
546    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
549def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
550    return self.like_sql(
551        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
552    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
555def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
556    zone = self.sql(expression, "this")
557    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
560def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
561    if expression.args.get("recursive"):
562        self.unsupported("Recursive CTEs are unsupported")
563        expression.args["recursive"] = False
564    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
567def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
568    n = self.sql(expression, "this")
569    d = self.sql(expression, "expression")
570    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
573def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
574    self.unsupported("TABLESAMPLE unsupported")
575    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
578def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
579    self.unsupported("PIVOT unsupported")
580    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
583def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
584    return self.cast_sql(expression)
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
587def no_comment_column_constraint_sql(
588    self: Generator, expression: exp.CommentColumnConstraint
589) -> str:
590    self.unsupported("CommentColumnConstraint unsupported")
591    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
594def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
595    self.unsupported("MAP_FROM_ENTRIES unsupported")
596    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
599def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
600    this = self.sql(expression, "this")
601    substr = self.sql(expression, "substr")
602    position = self.sql(expression, "position")
603    if position:
604        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
605    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
608def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
609    return (
610        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
611    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
614def var_map_sql(
615    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
616) -> str:
617    keys = expression.args["keys"]
618    values = expression.args["values"]
619
620    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
621        self.unsupported("Cannot convert array columns into map.")
622        return self.func(map_func_name, keys, values)
623
624    args = []
625    for key, value in zip(keys.expressions, values.expressions):
626        args.append(self.sql(key))
627        args.append(self.sql(value))
628
629    return self.func(map_func_name, *args)
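
This flattens MAP(ARRAY[...], ARRAY[...]) pairs into a variadic call. For example, transpiling Presto's map constructor to Hive (which wires this helper up for exp.Map) should yield roughly:

import sqlglot

sql = "SELECT MAP(ARRAY['a', 'b'], ARRAY[1, 2])"
print(sqlglot.transpile(sql, read="presto", write="hive")[0])
# SELECT MAP('a', 1, 'b', 2)
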
def build_formatted_time( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
632def build_formatted_time(
633    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
634) -> t.Callable[[t.List], E]:
635    """Helper used for time expressions.
636
637    Args:
638        exp_class: the expression class to instantiate.
639        dialect: target sql dialect.
640        default: the default format, True being time.
641
642    Returns:
643        A callable that can be used to return the appropriately formatted time expression.
644    """
645
646    def _builder(args: t.List):
647        return exp_class(
648            this=seq_get(args, 0),
649            format=Dialect[dialect].format_time(
650                seq_get(args, 1)
651                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
652            ),
653        )
654
655    return _builder

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
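
A sketch of using the builder directly, assuming Hive's TIME_MAPPING for the token conversion:

from sqlglot import exp
from sqlglot.dialects.dialect import build_formatted_time

builder = build_formatted_time(exp.StrToTime, "hive")
node = builder([exp.Literal.string("2024-01-01"), exp.Literal.string("yyyy-MM-dd")])
print(node.args["format"].sql())  # '%Y-%m-%d'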

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
658def time_format(
659    dialect: DialectType = None,
660) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
661    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
662        """
663        Returns the time format for a given expression, unless it's equivalent
664        to the default time format of the dialect of interest.
665        """
666        time_format = self.format_time(expression)
667        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
668
669    return _time_format
def build_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
672def build_date_delta(
673    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
674) -> t.Callable[[t.List], E]:
675    def _builder(args: t.List) -> E:
676        unit_based = len(args) == 3
677        this = args[2] if unit_based else seq_get(args, 0)
678        unit = args[0] if unit_based else exp.Literal.string("DAY")
679        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
680        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
681
682    return _builder
def build_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
685def build_date_delta_with_interval(
686    expression_class: t.Type[E],
687) -> t.Callable[[t.List], t.Optional[E]]:
688    def _builder(args: t.List) -> t.Optional[E]:
689        if len(args) < 2:
690            return None
691
692        interval = args[1]
693
694        if not isinstance(interval, exp.Interval):
695            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
696
697        expression = interval.this
698        if expression and expression.is_string:
699            expression = exp.Literal.number(expression.this)
700
701        return expression_class(
702            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
703        )
704
705    return _builder
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
708def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
709    unit = seq_get(args, 0)
710    this = seq_get(args, 1)
711
712    if isinstance(this, exp.Cast) and this.is_type("date"):
713        return exp.DateTrunc(unit=unit, this=this)
714    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
717def date_add_interval_sql(
718    data_type: str, kind: str
719) -> t.Callable[[Generator, exp.Expression], str]:
720    def func(self: Generator, expression: exp.Expression) -> str:
721        this = self.sql(expression, "this")
722        unit = expression.args.get("unit")
723        unit = exp.var(unit.name.upper() if unit else "DAY")
724        interval = exp.Interval(this=expression.expression, unit=unit)
725        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
726
727    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
730def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
731    return self.func(
732        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
733    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
736def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
737    if not expression.expression:
738        from sqlglot.optimizer.annotate_types import annotate_types
739
740        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
741        return self.sql(exp.cast(expression.this, to=target_type))
742    if expression.text("expression").lower() in TIMEZONES:
743        return self.sql(
744            exp.AtTimeZone(
745                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
746                zone=expression.expression,
747            )
748        )
749    return self.func("TIMESTAMP", expression.this, expression.expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
752def locate_to_strposition(args: t.List) -> exp.Expression:
753    return exp.StrPosition(
754        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
755    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
758def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
759    return self.func(
760        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
761    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
764def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
765    return self.sql(
766        exp.Substring(
767            this=expression.this, start=exp.Literal.number(1), length=expression.expression
768        )
769    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
772def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
773    return self.sql(
774        exp.Substring(
775            this=expression.this,
776            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
777        )
778    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
781def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
782    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
785def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
786    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
790def encode_decode_sql(
791    self: Generator, expression: exp.Expression, name: str, replace: bool = True
792) -> str:
793    charset = expression.args.get("charset")
794    if charset and charset.name.lower() != "utf-8":
795        self.unsupported(f"Expected utf-8 character set, got {charset}.")
796
797    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
800def min_or_least(self: Generator, expression: exp.Min) -> str:
801    name = "LEAST" if expression.expressions else "MIN"
802    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
805def max_or_greatest(self: Generator, expression: exp.Max) -> str:
806    name = "GREATEST" if expression.expressions else "MAX"
807    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
810def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
811    cond = expression.this
812
813    if isinstance(expression.this, exp.Distinct):
814        cond = expression.this.expressions[0]
815        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
816
817    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
820def trim_sql(self: Generator, expression: exp.Trim) -> str:
821    target = self.sql(expression, "this")
822    trim_type = self.sql(expression, "position")
823    remove_chars = self.sql(expression, "expression")
824    collation = self.sql(expression, "collation")
825
826    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
827    if not remove_chars and not collation:
828        return self.trim_sql(expression)
829
830    trim_type = f"{trim_type} " if trim_type else ""
831    remove_chars = f"{remove_chars} " if remove_chars else ""
832    from_part = "FROM " if trim_type or remove_chars else ""
833    collation = f" COLLATE {collation}" if collation else ""
834    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
837def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
838    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
841def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
842    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
845def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
846    delim, *rest_args = expression.expressions
847    return self.sql(
848        reduce(
849            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
850            rest_args,
851        )
852    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
855def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
856    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
857    if bad_args:
858        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
859
860    return self.func(
861        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
862    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
865def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
866    bad_args = list(
867        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
868    )
869    if bad_args:
870        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
871
872    return self.func(
873        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
874    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
877def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
878    names = []
879    for agg in aggregations:
880        if isinstance(agg, exp.Alias):
881            names.append(agg.alias)
882        else:
883            """
884            This case corresponds to aggregations without aliases being used as suffixes
885            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
886            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
887            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
888            """
889            agg_all_unquoted = agg.transform(
890                lambda node: (
891                    exp.Identifier(this=node.name, quoted=False)
892                    if isinstance(node, exp.Identifier)
893                    else node
894                )
895            )
896            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
897
898    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
901def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
902    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def build_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
906def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
907    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
910def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
911    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
914def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
915    a = self.sql(expression.left)
916    b = self.sql(expression.right)
917    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
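
In other words, x XOR y is expanded into its AND/OR/NOT equivalent. As an illustration (assuming the target dialect routes exp.Xor through this helper, as Postgres does):

import sqlglot

print(sqlglot.transpile("SELECT a XOR b", read="mysql", write="postgres")[0])
# SELECT (a AND (NOT b)) OR ((NOT a) AND b)
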
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
920def is_parse_json(expression: exp.Expression) -> bool:
921    return isinstance(expression, exp.ParseJSON) or (
922        isinstance(expression, exp.Cast) and expression.is_type("json")
923    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
926def isnull_to_is_null(args: t.List) -> exp.Expression:
927    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
930def generatedasidentitycolumnconstraint_sql(
931    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
932) -> str:
933    start = self.sql(expression, "start") or "1"
934    increment = self.sql(expression, "increment") or "1"
935    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
938def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
939    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
940        if expression.args.get("count"):
941            self.unsupported(f"Only two arguments are supported in function {name}.")
942
943        return self.func(name, expression.this, expression.expression)
944
945    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
948def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
949    this = expression.this.copy()
950
951    return_type = expression.return_type
952    if return_type.is_type(exp.DataType.Type.DATE):
953        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
954        # can truncate timestamp strings, because some dialects can't cast them to DATE
955        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
956
957    expression.this.replace(exp.cast(this, return_type))
958    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
961def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
962    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
963        if cast and isinstance(expression, exp.TsOrDsAdd):
964            expression = ts_or_ds_add_cast(expression)
965
966        return self.func(
967            name,
968            exp.var(expression.text("unit").upper() or "DAY"),
969            expression.expression,
970            expression.this,
971        )
972
973    return _delta_sql
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
976def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
977    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
978    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
979    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
980
981    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
 984def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 985    """Remove table refs from columns in when statements."""
 986    alias = expression.this.args.get("alias")
 987
 988    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
 989        return self.dialect.normalize_identifier(identifier).name if identifier else None
 990
 991    targets = {normalize(expression.this.this)}
 992
 993    if alias:
 994        targets.add(normalize(alias.this))
 995
 996    for when in expression.expressions:
 997        when.transform(
 998            lambda node: (
 999                exp.column(node.this)
1000                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1001                else node
1002            ),
1003            copy=False,
1004        )
1005
1006    return self.merge_sql(expression)

Remove table refs from columns in when statements.

def build_json_extract_path( expr_type: Type[~F], zero_based_indexing: bool = True) -> Callable[[List], ~F]:
1009def build_json_extract_path(
1010    expr_type: t.Type[F], zero_based_indexing: bool = True
1011) -> t.Callable[[t.List], F]:
1012    def _builder(args: t.List) -> F:
1013        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
1014        for arg in args[1:]:
1015            if not isinstance(arg, exp.Literal):
1016                # We use the fallback parser because we can't really transpile non-literals safely
1017                return expr_type.from_arg_list(args)
1018
1019            text = arg.name
1020            if is_int(text):
1021                index = int(text)
1022                segments.append(
1023                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
1024                )
1025            else:
1026                segments.append(exp.JSONPathKey(this=text))
1027
1028        # This is done to avoid failing in the expression validator due to the arg count
1029        del args[2:]
1030        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
1031
1032    return _builder
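
A sketch of the builder in action: literal arguments become structured path segments, while non-literals fall back to the plain function form:

from sqlglot import exp
from sqlglot.dialects.dialect import build_json_extract_path

builder = build_json_extract_path(exp.JSONExtract)
node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
print(node.expression.sql())  # $.a[0]
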
def json_extract_segments( name: str, quoted_index: bool = True, op: Optional[str] = None) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.JSONExtract, sqlglot.expressions.JSONExtractScalar]], str]:
1035def json_extract_segments(
1036    name: str, quoted_index: bool = True, op: t.Optional[str] = None
1037) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
1038    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
1039        path = expression.expression
1040        if not isinstance(path, exp.JSONPath):
1041            return rename_func(name)(self, expression)
1042
1043        segments = []
1044        for segment in path.expressions:
1045            path = self.sql(segment)
1046            if path:
1047                if isinstance(segment, exp.JSONPathPart) and (
1048                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
1049                ):
1050                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
1051
1052                segments.append(path)
1053
1054        if op:
1055            return f" {op} ".join([self.sql(expression.this), *segments])
1056        return self.func(name, expression.this, *segments)
1057
1058    return _json_extract_segments
def json_path_key_only_name( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONPathKey) -> str:
1061def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
1062    if isinstance(expression.this, exp.JSONPathWildcard):
1063        self.unsupported("Unsupported wildcard in JSONPathKey expression")
1064
1065    return expression.name
def filter_array_using_unnest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ArrayFilter) -> str:
1068def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
1069    cond = expression.expression
1070    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
1071        alias = cond.expressions[0]
1072        cond = cond.this
1073    elif isinstance(cond, exp.Predicate):
1074        alias = "_u"
1075    else:
1076        self.unsupported("Unsupported filter condition")
1077        return ""
1078
1079    unnest = exp.Unnest(expressions=[expression.this])
1080    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
1081    return self.sql(exp.Array(expressions=[filtered]))