sqlglot.dialects.dialect

   1from __future__ import annotations
   2
   3import typing as t
   4from enum import Enum, auto
   5from functools import reduce
   6
   7from sqlglot import exp
   8from sqlglot.errors import ParseError
   9from sqlglot.generator import Generator
  10from sqlglot.helper import AutoName, flatten, seq_get
  11from sqlglot.parser import Parser
  12from sqlglot.time import TIMEZONES, format_time
  13from sqlglot.tokens import Token, Tokenizer, TokenType
  14from sqlglot.trie import new_trie
  15
  16DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
  17DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
  18
  19if t.TYPE_CHECKING:
  20    from sqlglot._typing import B, E
  21
  22
  23class Dialects(str, Enum):
  24    """Dialects supported by SQLGlot."""
  25
  26    DIALECT = ""
  27
  28    BIGQUERY = "bigquery"
  29    CLICKHOUSE = "clickhouse"
  30    DATABRICKS = "databricks"
  31    DORIS = "doris"
  32    DRILL = "drill"
  33    DUCKDB = "duckdb"
  34    HIVE = "hive"
  35    MYSQL = "mysql"
  36    ORACLE = "oracle"
  37    POSTGRES = "postgres"
  38    PRESTO = "presto"
  39    REDSHIFT = "redshift"
  40    SNOWFLAKE = "snowflake"
  41    SPARK = "spark"
  42    SPARK2 = "spark2"
  43    SQLITE = "sqlite"
  44    STARROCKS = "starrocks"
  45    TABLEAU = "tableau"
  46    TERADATA = "teradata"
  47    TRINO = "trino"
  48    TSQL = "tsql"
  49
  50
  51class NormalizationStrategy(str, AutoName):
  52    """Specifies the strategy according to which identifiers should be normalized."""
  53
  54    LOWERCASE = auto()
  55    """Unquoted identifiers are lowercased."""
  56
  57    UPPERCASE = auto()
  58    """Unquoted identifiers are uppercased."""
  59
  60    CASE_SENSITIVE = auto()
  61    """Always case-sensitive, regardless of quotes."""
  62
  63    CASE_INSENSITIVE = auto()
  64    """Always case-insensitive, regardless of quotes."""
  65
  66
  67class _Dialect(type):
  68    classes: t.Dict[str, t.Type[Dialect]] = {}
  69
  70    def __eq__(cls, other: t.Any) -> bool:
  71        if cls is other:
  72            return True
  73        if isinstance(other, str):
  74            return cls is cls.get(other)
  75        if isinstance(other, Dialect):
  76            return cls is type(other)
  77
  78        return False
  79
  80    def __hash__(cls) -> int:
  81        return hash(cls.__name__.lower())
  82
  83    @classmethod
  84    def __getitem__(cls, key: str) -> t.Type[Dialect]:
  85        return cls.classes[key]
  86
  87    @classmethod
  88    def get(
  89        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
  90    ) -> t.Optional[t.Type[Dialect]]:
  91        return cls.classes.get(key, default)
  92
  93    def __new__(cls, clsname, bases, attrs):
  94        klass = super().__new__(cls, clsname, bases, attrs)
  95        enum = Dialects.__members__.get(clsname.upper())
  96        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
  97
  98        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
  99        klass.FORMAT_TRIE = (
 100            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 101        )
 102        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 103        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 104
 105        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
 106
 107        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 108        klass.parser_class = getattr(klass, "Parser", Parser)
 109        klass.generator_class = getattr(klass, "Generator", Generator)
 110
 111        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 112        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 113            klass.tokenizer_class._IDENTIFIERS.items()
 114        )[0]
 115
 116        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 117            return next(
 118                (
 119                    (s, e)
 120                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 121                    if t == token_type
 122                ),
 123                (None, None),
 124            )
 125
 126        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
 127        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
 128        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
 129        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 130
 131        if enum not in ("", "bigquery"):
 132            klass.generator_class.SELECT_KINDS = ()
 133
 134        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
 135            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
 136                TokenType.ANTI,
 137                TokenType.SEMI,
 138            }
 139
 140        return klass
 141
 142
 143class Dialect(metaclass=_Dialect):
 144    INDEX_OFFSET = 0
 145    """Determines the base index offset for arrays."""
 146
 147    WEEK_OFFSET = 0
 148    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
 149
 150    UNNEST_COLUMN_ONLY = False
 151    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
 152
 153    ALIAS_POST_TABLESAMPLE = False
 154    """Determines whether or not the table alias comes after tablesample."""
 155
 156    TABLESAMPLE_SIZE_IS_PERCENT = False
 157    """Determines whether or not a size in the table sample clause represents percentage."""
 158
 159    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
 160    """Specifies the strategy according to which identifiers should be normalized."""
 161
 162    IDENTIFIERS_CAN_START_WITH_DIGIT = False
 163    """Determines whether or not an unquoted identifier can start with a digit."""
 164
 165    DPIPE_IS_STRING_CONCAT = True
 166    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
 167
 168    STRICT_STRING_CONCAT = False
 169    """Determines whether or not `CONCAT`'s arguments must be strings."""
 170
 171    SUPPORTS_USER_DEFINED_TYPES = True
 172    """Determines whether or not user-defined data types are supported."""
 173
 174    SUPPORTS_SEMI_ANTI_JOIN = True
 175    """Determines whether or not `SEMI` or `ANTI` joins are supported."""
 176
 177    NORMALIZE_FUNCTIONS: bool | str = "upper"
 178    """Determines how function names are going to be normalized."""
 179
 180    LOG_BASE_FIRST = True
 181    """Determines whether the base comes first in the `LOG` function."""
 182
 183    NULL_ORDERING = "nulls_are_small"
 184    """
 185    Indicates the default `NULL` ordering method to use if not explicitly set.
 186    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
 187    """
 188
 189    TYPED_DIVISION = False
 190    """
 191    Whether the behavior of `a / b` depends on the types of `a` and `b`.
 192    False means `a / b` is always float division.
 193    True means `a / b` is integer division if both `a` and `b` are integers.
 194    """
 195
 196    SAFE_DIVISION = False
 197    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
 198
 199    CONCAT_COALESCE = False
 200    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
 201
 202    DATE_FORMAT = "'%Y-%m-%d'"
 203    DATEINT_FORMAT = "'%Y%m%d'"
 204    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
 205
 206    TIME_MAPPING: t.Dict[str, str] = {}
 207    """Associates this dialect's time formats with their equivalent Python `strftime` format."""
 208
 209    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
 210    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
 211    FORMAT_MAPPING: t.Dict[str, str] = {}
 212    """
 213    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
 214    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
 215    """
 216
 217    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 218    """Mapping of an unescaped escape sequence to the corresponding character."""
 219
 220    PSEUDOCOLUMNS: t.Set[str] = set()
 221    """
 222    Columns that are auto-generated by the engine corresponding to this dialect.
 223    For example, such columns may be excluded from `SELECT *` queries.
 224    """
 225
 226    PREFER_CTE_ALIAS_COLUMN = False
 227    """
 228    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
 229    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
 230    any projection aliases in the subquery.
 231
 232    For example,
 233        WITH y(c) AS (
 234            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
 235        ) SELECT c FROM y;
 236
 237        will be rewritten as
 238
 239        WITH y(c) AS (
 240            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
 241        ) SELECT c FROM y;
 242    """
 243
 244    # --- Autofilled ---
 245
 246    tokenizer_class = Tokenizer
 247    parser_class = Parser
 248    generator_class = Generator
 249
 250    # A trie of the time_mapping keys
 251    TIME_TRIE: t.Dict = {}
 252    FORMAT_TRIE: t.Dict = {}
 253
 254    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
 255    INVERSE_TIME_TRIE: t.Dict = {}
 256
 257    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
 258
 259    # Delimiters for quotes, identifiers and the corresponding escape characters
 260    QUOTE_START = "'"
 261    QUOTE_END = "'"
 262    IDENTIFIER_START = '"'
 263    IDENTIFIER_END = '"'
 264
 265    # Delimiters for bit, hex, byte and unicode literals
 266    BIT_START: t.Optional[str] = None
 267    BIT_END: t.Optional[str] = None
 268    HEX_START: t.Optional[str] = None
 269    HEX_END: t.Optional[str] = None
 270    BYTE_START: t.Optional[str] = None
 271    BYTE_END: t.Optional[str] = None
 272    UNICODE_START: t.Optional[str] = None
 273    UNICODE_END: t.Optional[str] = None
 274
 275    @classmethod
 276    def get_or_raise(cls, dialect: DialectType) -> Dialect:
 277        """
 278        Look up a dialect in the global dialect registry and return it if it exists.
 279
 280        Args:
 281            dialect: The target dialect. If this is a string, it can be optionally followed by
 282                additional key-value pairs that are separated by commas and are used to specify
 283                dialect settings, such as whether the dialect's identifiers are case-sensitive.
 284
 285        Example:
 286            >>> dialect = Dialect.get_or_raise("duckdb")
 287            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
 288
 289        Returns:
 290            The corresponding Dialect instance.
 291        """
 292
 293        if not dialect:
 294            return cls()
 295        if isinstance(dialect, _Dialect):
 296            return dialect()
 297        if isinstance(dialect, Dialect):
 298            return dialect
 299        if isinstance(dialect, str):
 300            try:
 301                dialect_name, *kv_pairs = dialect.split(",")
 302                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
 303            except ValueError:
 304                raise ValueError(
 305                    f"Invalid dialect format: '{dialect}'. "
 306                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
 307                )
 308
 309            result = cls.get(dialect_name.strip())
 310            if not result:
 311                from difflib import get_close_matches
 312
 313                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
 314                if similar:
 315                    similar = f" Did you mean {similar}?"
 316
 317                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
 318
 319            return result(**kwargs)
 320
 321        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
 322
 323    @classmethod
 324    def format_time(
 325        cls, expression: t.Optional[str | exp.Expression]
 326    ) -> t.Optional[exp.Expression]:
 327        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
 328        if isinstance(expression, str):
 329            return exp.Literal.string(
 330                # the time formats are quoted
 331                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
 332            )
 333
 334        if expression and expression.is_string:
 335            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
 336
 337        return expression
 338
 339    def __init__(self, **kwargs) -> None:
 340        normalization_strategy = kwargs.get("normalization_strategy")
 341
 342        if normalization_strategy is None:
 343            self.normalization_strategy = self.NORMALIZATION_STRATEGY
 344        else:
 345            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
 346
 347    def __eq__(self, other: t.Any) -> bool:
 348        # Does not currently take dialect state into account
 349        return type(self) == other
 350
 351    def __hash__(self) -> int:
 352        # Does not currently take dialect state into account
 353        return hash(type(self))
 354
 355    def normalize_identifier(self, expression: E) -> E:
 356        """
 357        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
 358
 359        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
 360        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
 361        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
 362        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
 363
 364        There are also dialects like Spark, which are case-insensitive even when quotes are
 365        present, and dialects like MySQL, whose resolution rules match those employed by the
 366        underlying operating system, for example they may always be case-sensitive in Linux.
 367
 368        Finally, the normalization behavior of some engines can even be controlled through flags,
 369        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
 370
 371        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
 372        that it can analyze queries in the optimizer and successfully capture their semantics.
 373        """
 374        if (
 375            isinstance(expression, exp.Identifier)
 376            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
 377            and (
 378                not expression.quoted
 379                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
 380            )
 381        ):
 382            expression.set(
 383                "this",
 384                (
 385                    expression.this.upper()
 386                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 387                    else expression.this.lower()
 388                ),
 389            )
 390
 391        return expression
 392
 393    def case_sensitive(self, text: str) -> bool:
 394        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
 395        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
 396            return False
 397
 398        unsafe = (
 399            str.islower
 400            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
 401            else str.isupper
 402        )
 403        return any(unsafe(char) for char in text)
 404
 405    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
 406        """Checks if text can be identified given an identify option.
 407
 408        Args:
 409            text: The text to check.
 410            identify:
 411                `"always"` or `True`: Always returns `True`.
 412                `"safe"`: Only returns `True` if the identifier is case-insensitive.
 413
 414        Returns:
 415            Whether or not the given text can be identified.
 416        """
 417        if identify is True or identify == "always":
 418            return True
 419
 420        if identify == "safe":
 421            return not self.case_sensitive(text)
 422
 423        return False
 424
 425    def quote_identifier(self, expression: E, identify: bool = True) -> E:
 426        """
 427        Adds quotes to a given identifier.
 428
 429        Args:
 430            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
 431            identify: If set to `False`, the quotes will only be added if the identifier is deemed
 432                "unsafe", with respect to its characters and this dialect's normalization strategy.
 433        """
 434        if isinstance(expression, exp.Identifier):
 435            name = expression.this
 436            expression.set(
 437                "quoted",
 438                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
 439            )
 440
 441        return expression
 442
 443    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
 444        return self.parser(**opts).parse(self.tokenize(sql), sql)
 445
 446    def parse_into(
 447        self, expression_type: exp.IntoType, sql: str, **opts
 448    ) -> t.List[t.Optional[exp.Expression]]:
 449        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 450
 451    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
 452        return self.generator(**opts).generate(expression, copy=copy)
 453
 454    def transpile(self, sql: str, **opts) -> t.List[str]:
 455        return [
 456            self.generate(expression, copy=False, **opts) if expression else ""
 457            for expression in self.parse(sql)
 458        ]
 459
 460    def tokenize(self, sql: str) -> t.List[Token]:
 461        return self.tokenizer.tokenize(sql)
 462
 463    @property
 464    def tokenizer(self) -> Tokenizer:
 465        if not hasattr(self, "_tokenizer"):
 466            self._tokenizer = self.tokenizer_class(dialect=self)
 467        return self._tokenizer
 468
 469    def parser(self, **opts) -> Parser:
 470        return self.parser_class(dialect=self, **opts)
 471
 472    def generator(self, **opts) -> Generator:
 473        return self.generator_class(dialect=self, **opts)
 474
 475
 476DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 477
 478
 479def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
 480    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
 481
 482
 483def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
 484    if expression.args.get("accuracy"):
 485        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
 486    return self.func("APPROX_COUNT_DISTINCT", expression.this)
 487
 488
 489def if_sql(
 490    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
 491) -> t.Callable[[Generator, exp.If], str]:
 492    def _if_sql(self: Generator, expression: exp.If) -> str:
 493        return self.func(
 494            name,
 495            expression.this,
 496            expression.args.get("true"),
 497            expression.args.get("false") or false_value,
 498        )
 499
 500    return _if_sql
 501
 502
 503def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
 504    return self.binary(expression, "->")
 505
 506
 507def arrow_json_extract_scalar_sql(
 508    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
 509) -> str:
 510    return self.binary(expression, "->>")
 511
 512
 513def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 514    return f"[{self.expressions(expression, flat=True)}]"
 515
 516
 517def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
 518    return self.like_sql(
 519        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
 520    )
 521
 522
 523def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
 524    zone = self.sql(expression, "this")
 525    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 526
 527
 528def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
 529    if expression.args.get("recursive"):
 530        self.unsupported("Recursive CTEs are unsupported")
 531        expression.args["recursive"] = False
 532    return self.with_sql(expression)
 533
 534
 535def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
 536    n = self.sql(expression, "this")
 537    d = self.sql(expression, "expression")
 538    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
 539
 540
 541def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
 542    self.unsupported("TABLESAMPLE unsupported")
 543    return self.sql(expression.this)
 544
 545
 546def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
 547    self.unsupported("PIVOT unsupported")
 548    return ""
 549
 550
 551def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
 552    return self.cast_sql(expression)
 553
 554
 555def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
 556    self.unsupported("Properties unsupported")
 557    return ""
 558
 559
 560def no_comment_column_constraint_sql(
 561    self: Generator, expression: exp.CommentColumnConstraint
 562) -> str:
 563    self.unsupported("CommentColumnConstraint unsupported")
 564    return ""
 565
 566
 567def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
 568    self.unsupported("MAP_FROM_ENTRIES unsupported")
 569    return ""
 570
 571
 572def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
 573    this = self.sql(expression, "this")
 574    substr = self.sql(expression, "substr")
 575    position = self.sql(expression, "position")
 576    if position:
 577        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
 578    return f"STRPOS({this}, {substr})"
 579
 580
 581def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
 582    return (
 583        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
 584    )
 585
 586
 587def var_map_sql(
 588    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
 589) -> str:
 590    keys = expression.args["keys"]
 591    values = expression.args["values"]
 592
 593    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
 594        self.unsupported("Cannot convert array columns into map.")
 595        return self.func(map_func_name, keys, values)
 596
 597    args = []
 598    for key, value in zip(keys.expressions, values.expressions):
 599        args.append(self.sql(key))
 600        args.append(self.sql(value))
 601
 602    return self.func(map_func_name, *args)
 603
 604
 605def format_time_lambda(
 606    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
 607) -> t.Callable[[t.List], E]:
 608    """Helper used for time expressions.
 609
 610    Args:
 611        exp_class: the expression class to instantiate.
 612        dialect: target sql dialect.
 613        default: the default format, True being time.
 614
 615    Returns:
 616        A callable that can be used to return the appropriately formatted time expression.
 617    """
 618
 619    def _format_time(args: t.List):
 620        return exp_class(
 621            this=seq_get(args, 0),
 622            format=Dialect[dialect].format_time(
 623                seq_get(args, 1)
 624                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
 625            ),
 626        )
 627
 628    return _format_time
 629
 630
 631def time_format(
 632    dialect: DialectType = None,
 633) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
 634    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
 635        """
 636        Returns the time format for a given expression, unless it's equivalent
 637        to the default time format of the dialect of interest.
 638        """
 639        time_format = self.format_time(expression)
 640        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
 641
 642    return _time_format
 643
 644
 645def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
 646    """
 647    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
 648    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
 649    columns are removed from the create statement.
 650    """
 651    has_schema = isinstance(expression.this, exp.Schema)
 652    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
 653
 654    if has_schema and is_partitionable:
 655        prop = expression.find(exp.PartitionedByProperty)
 656        if prop and prop.this and not isinstance(prop.this, exp.Schema):
 657            schema = expression.this
 658            columns = {v.name.upper() for v in prop.this.expressions}
 659            partitions = [col for col in schema.expressions if col.name.upper() in columns]
 660            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
 661            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
 662            expression.set("this", schema)
 663
 664    return self.create_sql(expression)
 665
 666
 667def parse_date_delta(
 668    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 669) -> t.Callable[[t.List], E]:
 670    def inner_func(args: t.List) -> E:
 671        unit_based = len(args) == 3
 672        this = args[2] if unit_based else seq_get(args, 0)
 673        unit = args[0] if unit_based else exp.Literal.string("DAY")
 674        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
 675        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
 676
 677    return inner_func
 678
 679
 680def parse_date_delta_with_interval(
 681    expression_class: t.Type[E],
 682) -> t.Callable[[t.List], t.Optional[E]]:
 683    def func(args: t.List) -> t.Optional[E]:
 684        if len(args) < 2:
 685            return None
 686
 687        interval = args[1]
 688
 689        if not isinstance(interval, exp.Interval):
 690            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
 691
 692        expression = interval.this
 693        if expression and expression.is_string:
 694            expression = exp.Literal.number(expression.this)
 695
 696        return expression_class(
 697            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
 698        )
 699
 700    return func
 701
 702
 703def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
 704    unit = seq_get(args, 0)
 705    this = seq_get(args, 1)
 706
 707    if isinstance(this, exp.Cast) and this.is_type("date"):
 708        return exp.DateTrunc(unit=unit, this=this)
 709    return exp.TimestampTrunc(this=this, unit=unit)
 710
 711
 712def date_add_interval_sql(
 713    data_type: str, kind: str
 714) -> t.Callable[[Generator, exp.Expression], str]:
 715    def func(self: Generator, expression: exp.Expression) -> str:
 716        this = self.sql(expression, "this")
 717        unit = expression.args.get("unit")
 718        unit = exp.var(unit.name.upper() if unit else "DAY")
 719        interval = exp.Interval(this=expression.expression, unit=unit)
 720        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 721
 722    return func
 723
 724
 725def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 726    return self.func(
 727        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
 728    )
 729
 730
 731def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
 732    if not expression.expression:
 733        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
 734    if expression.text("expression").lower() in TIMEZONES:
 735        return self.sql(
 736            exp.AtTimeZone(
 737                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
 738                zone=expression.expression,
 739            )
 740        )
 741    return self.function_fallback_sql(expression)
 742
 743
 744def locate_to_strposition(args: t.List) -> exp.Expression:
 745    return exp.StrPosition(
 746        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
 747    )
 748
 749
 750def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
 751    return self.func(
 752        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
 753    )
 754
 755
 756def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 757    return self.sql(
 758        exp.Substring(
 759            this=expression.this, start=exp.Literal.number(1), length=expression.expression
 760        )
 761    )
 762
 763
 764def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
 765    return self.sql(
 766        exp.Substring(
 767            this=expression.this,
 768            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
 769        )
 770    )
 771
 772
 773def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
 774    return self.sql(exp.cast(expression.this, "timestamp"))
 775
 776
 777def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
 778    return self.sql(exp.cast(expression.this, "date"))
 779
 780
 781# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
 782def encode_decode_sql(
 783    self: Generator, expression: exp.Expression, name: str, replace: bool = True
 784) -> str:
 785    charset = expression.args.get("charset")
 786    if charset and charset.name.lower() != "utf-8":
 787        self.unsupported(f"Expected utf-8 character set, got {charset}.")
 788
 789    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
 790
 791
 792def min_or_least(self: Generator, expression: exp.Min) -> str:
 793    name = "LEAST" if expression.expressions else "MIN"
 794    return rename_func(name)(self, expression)
 795
 796
 797def max_or_greatest(self: Generator, expression: exp.Max) -> str:
 798    name = "GREATEST" if expression.expressions else "MAX"
 799    return rename_func(name)(self, expression)
 800
 801
 802def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
 803    cond = expression.this
 804
 805    if isinstance(expression.this, exp.Distinct):
 806        cond = expression.this.expressions[0]
 807        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 808
 809    return self.func("sum", exp.func("if", cond, 1, 0))
 810
 811
 812def trim_sql(self: Generator, expression: exp.Trim) -> str:
 813    target = self.sql(expression, "this")
 814    trim_type = self.sql(expression, "position")
 815    remove_chars = self.sql(expression, "expression")
 816    collation = self.sql(expression, "collation")
 817
 818    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
 819    if not remove_chars and not collation:
 820        return self.trim_sql(expression)
 821
 822    trim_type = f"{trim_type} " if trim_type else ""
 823    remove_chars = f"{remove_chars} " if remove_chars else ""
 824    from_part = "FROM " if trim_type or remove_chars else ""
 825    collation = f" COLLATE {collation}" if collation else ""
 826    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
 827
 828
 829def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
 830    return self.func("STRPTIME", expression.this, self.format_time(expression))
 831
 832
 833def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
 834    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 835
 836
 837def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
 838    delim, *rest_args = expression.expressions
 839    return self.sql(
 840        reduce(
 841            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
 842            rest_args,
 843        )
 844    )
 845
 846
 847def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 848    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
 849    if bad_args:
 850        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
 851
 852    return self.func(
 853        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
 854    )
 855
 856
 857def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
 858    bad_args = list(
 859        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
 860    )
 861    if bad_args:
 862        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 863
 864    return self.func(
 865        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
 866    )
 867
 868
 869def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
 870    names = []
 871    for agg in aggregations:
 872        if isinstance(agg, exp.Alias):
 873            names.append(agg.alias)
 874        else:
 875            """
 876            This case corresponds to aggregations without aliases being used as suffixes
 877            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
 878            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
 879            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
 880            """
 881            agg_all_unquoted = agg.transform(
 882                lambda node: (
 883                    exp.Identifier(this=node.name, quoted=False)
 884                    if isinstance(node, exp.Identifier)
 885                    else node
 886                )
 887            )
 888            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
 889
 890    return names
 891
 892
 893def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 894    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
 895
 896
 897# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 898def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
 899    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
 900
 901
 902def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 903    return self.func("MAX", expression.this)
 904
 905
 906def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
 907    a = self.sql(expression.left)
 908    b = self.sql(expression.right)
 909    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
 910
 911
 912def is_parse_json(expression: exp.Expression) -> bool:
 913    return isinstance(expression, exp.ParseJSON) or (
 914        isinstance(expression, exp.Cast) and expression.is_type("json")
 915    )
 916
 917
 918def isnull_to_is_null(args: t.List) -> exp.Expression:
 919    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 920
 921
 922def generatedasidentitycolumnconstraint_sql(
 923    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 924) -> str:
 925    start = self.sql(expression, "start") or "1"
 926    increment = self.sql(expression, "increment") or "1"
 927    return f"IDENTITY({start}, {increment})"
 928
 929
 930def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
 931    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
 932        if expression.args.get("count"):
 933            self.unsupported(f"Only two arguments are supported in function {name}.")
 934
 935        return self.func(name, expression.this, expression.expression)
 936
 937    return _arg_max_or_min_sql
 938
 939
 940def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
 941    this = expression.this.copy()
 942
 943    return_type = expression.return_type
 944    if return_type.is_type(exp.DataType.Type.DATE):
 945        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
 946        # can truncate timestamp strings, because some dialects can't cast them to DATE
 947        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
 948
 949    expression.this.replace(exp.cast(this, return_type))
 950    return expression
 951
 952
 953def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
 954    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
 955        if cast and isinstance(expression, exp.TsOrDsAdd):
 956            expression = ts_or_ds_add_cast(expression)
 957
 958        return self.func(
 959            name,
 960            exp.var(expression.text("unit").upper() or "DAY"),
 961            expression.expression,
 962            expression.this,
 963        )
 964
 965    return _delta_sql
 966
 967
 968def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
 969    from sqlglot.optimizer.simplify import simplify
 970
 971    # Makes sure the path will be evaluated correctly at runtime to include the path root.
 972    # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
 973    path = expression.expression
 974    path = exp.func(
 975        "if",
 976        exp.func("startswith", path, "'['"),
 977        exp.func("concat", "'$'", path),
 978        exp.func("concat", "'$.'", path),
 979    )
 980
 981    expression.expression.replace(simplify(path))
 982    return expression
 983
 984
 985def path_to_jsonpath(
 986    name: str = "JSON_EXTRACT",
 987) -> t.Callable[[Generator, exp.GetPath], str]:
 988    def _transform(self: Generator, expression: exp.GetPath) -> str:
 989        return rename_func(name)(self, prepend_dollar_to_path(expression))
 990
 991    return _transform
 992
 993
 994def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
 995    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
 996    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
 997    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
 998
 999    return self.sql(exp.cast(minus_one_day, "date"))
1000
1001
1002def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1003    """Remove table refs from columns in when statements."""
1004    alias = expression.this.args.get("alias")
1005
1006    normalize = lambda identifier: (
1007        self.dialect.normalize_identifier(identifier).name if identifier else None
1008    )
1009
1010    targets = {normalize(expression.this.this)}
1011
1012    if alias:
1013        targets.add(normalize(alias.this))
1014
1015    for when in expression.expressions:
1016        when.transform(
1017            lambda node: (
1018                exp.column(node.this)
1019                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1020                else node
1021            ),
1022            copy=False,
1023        )
1024
1025    return self.merge_sql(expression)
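
As a quick orientation, the doctest below sketches how the pieces above fit together: `Dialect.get_or_raise` resolves a dialect from a name (optionally followed by comma-separated settings), and the resulting instance exposes `tokenize`/`parse`/`generate` plus the identifier helpers. The output strings are illustrative, assuming the defaults shown above; they may vary slightly across sqlglot versions.

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect.transpile("SELECT a FROM t")
['SELECT a FROM t']
>>> dialect.normalize_identifier(exp.to_identifier("FoO")).name
'foo'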
class Dialects(builtins.str, enum.Enum):

Dialects supported by SQLGlot.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
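
Each member's value is the name string accepted throughout the public API, so a `Dialects` member can be passed anywhere a dialect is expected. A minimal sketch using the top-level `sqlglot.transpile`, which resolves dialects via `Dialect.get_or_raise`:

>>> import sqlglot
>>> from sqlglot.dialects.dialect import Dialects
>>> sqlglot.transpile("SELECT 1", read=Dialects.DUCKDB, write=Dialects.HIVE)
['SELECT 1']
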
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.
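
A strategy can be chosen per dialect instance, either through the settings-string form of `Dialect.get_or_raise` or the constructor. A minimal sketch of `CASE_SENSITIVE`, under which `normalize_identifier` leaves identifiers untouched:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
>>> d.normalize_identifier(exp.to_identifier("FoO")).name
'FoO'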

class Dialect:
144class Dialect(metaclass=_Dialect):
145    INDEX_OFFSET = 0
146    """Determines the base index offset for arrays."""
147
148    WEEK_OFFSET = 0
149    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
150
151    UNNEST_COLUMN_ONLY = False
152    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
153
154    ALIAS_POST_TABLESAMPLE = False
155    """Determines whether or not the table alias comes after tablesample."""
156
157    TABLESAMPLE_SIZE_IS_PERCENT = False
158    """Determines whether or not a size in the table sample clause represents percentage."""
159
160    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
161    """Specifies the strategy according to which identifiers should be normalized."""
162
163    IDENTIFIERS_CAN_START_WITH_DIGIT = False
164    """Determines whether or not an unquoted identifier can start with a digit."""
165
166    DPIPE_IS_STRING_CONCAT = True
167    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
168
169    STRICT_STRING_CONCAT = False
170    """Determines whether or not `CONCAT`'s arguments must be strings."""
171
172    SUPPORTS_USER_DEFINED_TYPES = True
173    """Determines whether or not user-defined data types are supported."""
174
175    SUPPORTS_SEMI_ANTI_JOIN = True
176    """Determines whether or not `SEMI` or `ANTI` joins are supported."""
177
178    NORMALIZE_FUNCTIONS: bool | str = "upper"
179    """Determines how function names are going to be normalized."""
180
181    LOG_BASE_FIRST = True
182    """Determines whether the base comes first in the `LOG` function."""
183
184    NULL_ORDERING = "nulls_are_small"
185    """
186    Indicates the default `NULL` ordering method to use if not explicitly set.
187    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
188    """
189
190    TYPED_DIVISION = False
191    """
192    Whether the behavior of `a / b` depends on the types of `a` and `b`.
193    False means `a / b` is always float division.
194    True means `a / b` is integer division if both `a` and `b` are integers.
195    """
196
197    SAFE_DIVISION = False
198    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
199
200    CONCAT_COALESCE = False
201    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
202
203    DATE_FORMAT = "'%Y-%m-%d'"
204    DATEINT_FORMAT = "'%Y%m%d'"
205    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
206
207    TIME_MAPPING: t.Dict[str, str] = {}
208    """Associates this dialect's time formats with their equivalent Python `strftime` format."""
209
210    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
211    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
212    FORMAT_MAPPING: t.Dict[str, str] = {}
213    """
214    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
215    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
216    """
217
218    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
219    """Mapping of an unescaped escape sequence to the corresponding character."""
220
221    PSEUDOCOLUMNS: t.Set[str] = set()
222    """
223    Columns that are auto-generated by the engine corresponding to this dialect.
224    For example, such columns may be excluded from `SELECT *` queries.
225    """
226
227    PREFER_CTE_ALIAS_COLUMN = False
228    """
229    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
230    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
231    any projection aliases in the subquery.
232
233    For example,
234        WITH y(c) AS (
235            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
236        ) SELECT c FROM y;
237
238        will be rewritten as
239
240        WITH y(c) AS (
241            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
242        ) SELECT c FROM y;
243    """
244
245    # --- Autofilled ---
246
247    tokenizer_class = Tokenizer
248    parser_class = Parser
249    generator_class = Generator
250
251    # A trie of the time_mapping keys
252    TIME_TRIE: t.Dict = {}
253    FORMAT_TRIE: t.Dict = {}
254
255    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
256    INVERSE_TIME_TRIE: t.Dict = {}
257
258    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
259
260    # Delimiters for quotes, identifiers and the corresponding escape characters
261    QUOTE_START = "'"
262    QUOTE_END = "'"
263    IDENTIFIER_START = '"'
264    IDENTIFIER_END = '"'
265
266    # Delimiters for bit, hex, byte and unicode literals
267    BIT_START: t.Optional[str] = None
268    BIT_END: t.Optional[str] = None
269    HEX_START: t.Optional[str] = None
270    HEX_END: t.Optional[str] = None
271    BYTE_START: t.Optional[str] = None
272    BYTE_END: t.Optional[str] = None
273    UNICODE_START: t.Optional[str] = None
274    UNICODE_END: t.Optional[str] = None
275
276    @classmethod
277    def get_or_raise(cls, dialect: DialectType) -> Dialect:
278        """
279        Look up a dialect in the global dialect registry and return it if it exists.
280
281        Args:
282            dialect: The target dialect. If this is a string, it can be optionally followed by
283                additional key-value pairs that are separated by commas and are used to specify
284                dialect settings, such as whether the dialect's identifiers are case-sensitive.
285
286        Example:
287            >>> dialect = dialect_class = get_or_raise("duckdb")
288            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
289
290        Returns:
291            The corresponding Dialect instance.
292        """
293
294        if not dialect:
295            return cls()
296        if isinstance(dialect, _Dialect):
297            return dialect()
298        if isinstance(dialect, Dialect):
299            return dialect
300        if isinstance(dialect, str):
301            try:
302                dialect_name, *kv_pairs = dialect.split(",")
303                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
304            except ValueError:
305                raise ValueError(
306                    f"Invalid dialect format: '{dialect}'. "
307                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
308                )
309
310            result = cls.get(dialect_name.strip())
311            if not result:
312                from difflib import get_close_matches
313
314                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
315                if similar:
316                    similar = f" Did you mean {similar}?"
317
318                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
319
320            return result(**kwargs)
321
322        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
323
324    @classmethod
325    def format_time(
326        cls, expression: t.Optional[str | exp.Expression]
327    ) -> t.Optional[exp.Expression]:
328        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
329        if isinstance(expression, str):
330            return exp.Literal.string(
331                # the time formats are quoted
332                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
333            )
334
335        if expression and expression.is_string:
336            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
337
338        return expression
339
340    def __init__(self, **kwargs) -> None:
341        normalization_strategy = kwargs.get("normalization_strategy")
342
343        if normalization_strategy is None:
344            self.normalization_strategy = self.NORMALIZATION_STRATEGY
345        else:
346            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
347
348    def __eq__(self, other: t.Any) -> bool:
349        # Does not currently take dialect state into account
350        return type(self) == other
351
352    def __hash__(self) -> int:
353        # Does not currently take dialect state into account
354        return hash(type(self))
355
356    def normalize_identifier(self, expression: E) -> E:
357        """
358        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
359
360        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
361        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
362        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
363        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
364
365        There are also dialects like Spark, which are case-insensitive even when quotes are
366        present, and dialects like MySQL, whose resolution rules match those employed by the
367        underlying operating system, for example they may always be case-sensitive in Linux.
368
369        Finally, the normalization behavior of some engines can even be controlled through flags,
370        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
371
372        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
373        that it can analyze queries in the optimizer and successfully capture their semantics.
374        """
375        if (
376            isinstance(expression, exp.Identifier)
377            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
378            and (
379                not expression.quoted
380                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
381            )
382        ):
383            expression.set(
384                "this",
385                (
386                    expression.this.upper()
387                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
388                    else expression.this.lower()
389                ),
390            )
391
392        return expression
393
394    def case_sensitive(self, text: str) -> bool:
395        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
396        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
397            return False
398
399        unsafe = (
400            str.islower
401            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
402            else str.isupper
403        )
404        return any(unsafe(char) for char in text)
405
406    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
407        """Checks if text can be identified given an identify option.
408
409        Args:
410            text: The text to check.
411            identify:
412                `"always"` or `True`: Always returns `True`.
413                `"safe"`: Only returns `True` if the identifier is case-insensitive.
414
415        Returns:
416            Whether or not the given text can be identified.
417        """
418        if identify is True or identify == "always":
419            return True
420
421        if identify == "safe":
422            return not self.case_sensitive(text)
423
424        return False
425
426    def quote_identifier(self, expression: E, identify: bool = True) -> E:
427        """
428        Adds quotes to a given identifier.
429
430        Args:
431            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
432            identify: If set to `False`, the quotes will only be added if the identifier is deemed
433                "unsafe", with respect to its characters and this dialect's normalization strategy.
434        """
435        if isinstance(expression, exp.Identifier):
436            name = expression.this
437            expression.set(
438                "quoted",
439                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
440            )
441
442        return expression
443
444    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
445        return self.parser(**opts).parse(self.tokenize(sql), sql)
446
447    def parse_into(
448        self, expression_type: exp.IntoType, sql: str, **opts
449    ) -> t.List[t.Optional[exp.Expression]]:
450        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
451
452    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
453        return self.generator(**opts).generate(expression, copy=copy)
454
455    def transpile(self, sql: str, **opts) -> t.List[str]:
456        return [
457            self.generate(expression, copy=False, **opts) if expression else ""
458            for expression in self.parse(sql)
459        ]
460
461    def tokenize(self, sql: str) -> t.List[Token]:
462        return self.tokenizer.tokenize(sql)
463
464    @property
465    def tokenizer(self) -> Tokenizer:
466        if not hasattr(self, "_tokenizer"):
467            self._tokenizer = self.tokenizer_class(dialect=self)
468        return self._tokenizer
469
470    def parser(self, **opts) -> Parser:
471        return self.parser_class(dialect=self, **opts)
472
473    def generator(self, **opts) -> Generator:
474        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
340    def __init__(self, **kwargs) -> None:
341        normalization_strategy = kwargs.get("normalization_strategy")
342
343        if normalization_strategy is None:
344            self.normalization_strategy = self.NORMALIZATION_STRATEGY
345        else:
346            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
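
For example, the strategy can be overridden per instance; a minimal sketch (note
that the string is upper-cased before being coerced into a NormalizationStrategy
member):

>>> from sqlglot.dialects.dialect import Dialect, NormalizationStrategy
>>> d = Dialect(normalization_strategy="case_sensitive")
>>> d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
True
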
INDEX_OFFSET = 0

Determines the base index offset for arrays.

WEEK_OFFSET = 0

Determines the first day of the week used by DATE_TRUNC(week). Defaults to 0 (Monday); -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Determines whether or not UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Determines whether or not the table alias comes after tablesample.

TABLESAMPLE_SIZE_IS_PERCENT = False

Determines whether or not a size in the table sample clause represents percentage.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Determines whether or not an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Determines whether or not the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Determines whether or not CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Determines whether or not user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Determines whether or not SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are normalized when generating SQL.

LOG_BASE_FIRST = True

Determines whether the base comes first in the LOG function.

NULL_ORDERING = 'nulls_are_small'

Indicates the default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. If False, a / b is always float division; if True, a / b is integer division when both a and b are integers.

SAFE_DIVISION = False

Determines whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime format.

FORMAT_MAPPING: Dict[str, str] = {}

Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie will be constructed from TIME_MAPPING.

ESCAPE_SEQUENCES: Dict[str, str] = {}

Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
276    @classmethod
277    def get_or_raise(cls, dialect: DialectType) -> Dialect:
278        """
279        Look up a dialect in the global dialect registry and return it if it exists.
280
281        Args:
282            dialect: The target dialect. If this is a string, it can be optionally followed by
283                additional key-value pairs that are separated by commas and are used to specify
284                dialect settings, such as whether the dialect's identifiers are case-sensitive.
285
286        Example:
287            >>> dialect = Dialect.get_or_raise("duckdb")
288            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
289
290        Returns:
291            The corresponding Dialect instance.
292        """
293
294        if not dialect:
295            return cls()
296        if isinstance(dialect, _Dialect):
297            return dialect()
298        if isinstance(dialect, Dialect):
299            return dialect
300        if isinstance(dialect, str):
301            try:
302                dialect_name, *kv_pairs = dialect.split(",")
303                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
304            except ValueError:
305                raise ValueError(
306                    f"Invalid dialect format: '{dialect}'. "
307                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
308                )
309
310            result = cls.get(dialect_name.strip())
311            if not result:
312                from difflib import get_close_matches
313
314                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
315                if similar:
316                    similar = f" Did you mean {similar}?"
317
318                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
319
320            return result(**kwargs)
321
322        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
324    @classmethod
325    def format_time(
326        cls, expression: t.Optional[str | exp.Expression]
327    ) -> t.Optional[exp.Expression]:
328        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
329        if isinstance(expression, str):
330            return exp.Literal.string(
331                # the time formats are quoted
332                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
333            )
334
335        if expression and expression.is_string:
336            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
337
338        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
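
A small sketch, assuming a dialect whose TIME_MAPPING covers the tokens involved
(Hive is used for illustration; note that a plain string argument is assumed to
be quoted, so its first and last characters are stripped):

>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect["hive"].format_time("'yyyy-MM-dd'").sql()
"'%Y-%m-%d'"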

def normalize_identifier(self, expression: ~E) -> ~E:
356    def normalize_identifier(self, expression: E) -> E:
357        """
358        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
359
360        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
361        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
362        it would resolve it as `FOO`. If it were quoted, it'd need to be treated as case-sensitive,
363        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
364
365        There are also dialects like Spark, which are case-insensitive even when quotes are
366        present, and dialects like MySQL, whose resolution rules match those employed by the
367        underlying operating system; for example, they may always be case-sensitive on Linux.
368
369        Finally, the normalization behavior of some engines can even be controlled through flags,
370        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
371
372        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
373        that it can analyze queries in the optimizer and successfully capture their semantics.
374        """
375        if (
376            isinstance(expression, exp.Identifier)
377            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
378            and (
379                not expression.quoted
380                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
381            )
382        ):
383            expression.set(
384                "this",
385                (
386                    expression.this.upper()
387                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
388                    else expression.this.lower()
389                ),
390            )
391
392        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it were quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
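
A quick sketch of those strategies, assuming the dialects' default settings:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
'foo'
>>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
'FOO'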

def case_sensitive(self, text: str) -> bool:
394    def case_sensitive(self, text: str) -> bool:
395        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
396        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
397            return False
398
399        unsafe = (
400            str.islower
401            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
402            else str.isupper
403        )
404        return any(unsafe(char) for char in text)

Checks if text contains any case-sensitive characters, based on the dialect's rules.
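
For a dialect that lowercases unquoted identifiers, any uppercase character makes
the text case-sensitive; a minimal sketch using DuckDB's defaults:

>>> from sqlglot.dialects.dialect import Dialect
>>> duckdb = Dialect.get_or_raise("duckdb")
>>> duckdb.case_sensitive("foo"), duckdb.case_sensitive("Foo")
(False, True)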

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
406    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
407        """Checks if text can be identified given an identify option.
408
409        Args:
410            text: The text to check.
411            identify:
412                `"always"` or `True`: Always returns `True`.
413                `"safe"`: Only returns `True` if the identifier is case-insensitive.
414
415        Returns:
416            Whether or not the given text can be identified.
417        """
418        if identify is True or identify == "always":
419            return True
420
421        if identify == "safe":
422            return not self.case_sensitive(text)
423
424        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify:
      "always" or True: Always returns True.
      "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.
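
A short sketch of the two modes, again using DuckDB's defaults:

>>> from sqlglot.dialects.dialect import Dialect
>>> duckdb = Dialect.get_or_raise("duckdb")
>>> duckdb.can_identify("Foo")
False
>>> duckdb.can_identify("Foo", "always")
True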

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
426    def quote_identifier(self, expression: E, identify: bool = True) -> E:
427        """
428        Adds quotes to a given identifier.
429
430        Args:
431            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
432            identify: If set to `False`, the quotes will only be added if the identifier is deemed
433                "unsafe", with respect to its characters and this dialect's normalization strategy.
434        """
435        if isinstance(expression, exp.Identifier):
436            name = expression.this
437            expression.set(
438                "quoted",
439                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
440            )
441
442        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
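
For example, with identify=False only "unsafe" names are quoted; a minimal sketch:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import Dialect
>>> duckdb = Dialect.get_or_raise("duckdb")
>>> duckdb.quote_identifier(exp.to_identifier("foo"), identify=False).sql()
'foo'
>>> duckdb.quote_identifier(exp.to_identifier("foo bar"), identify=False).sql()
'"foo bar"'
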
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
444    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
445        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
447    def parse_into(
448        self, expression_type: exp.IntoType, sql: str, **opts
449    ) -> t.List[t.Optional[exp.Expression]]:
450        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
452    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
453        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
455    def transpile(self, sql: str, **opts) -> t.List[str]:
456        return [
457            self.generate(expression, copy=False, **opts) if expression else ""
458            for expression in self.parse(sql)
459        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
461    def tokenize(self, sql: str) -> t.List[Token]:
462        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
464    @property
465    def tokenizer(self) -> Tokenizer:
466        if not hasattr(self, "_tokenizer"):
467            self._tokenizer = self.tokenizer_class(dialect=self)
468        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
470    def parser(self, **opts) -> Parser:
471        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
473    def generator(self, **opts) -> Generator:
474        return self.generator_class(dialect=self, **opts)
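
Taken together, these methods give a dialect-aware round trip; a small sketch:

>>> from sqlglot.dialects.dialect import Dialect
>>> mysql = Dialect.get_or_raise("mysql")
>>> ast = mysql.parse("SELECT `a` FROM t")[0]
>>> mysql.generate(ast)
'SELECT `a` FROM t'
>>> mysql.transpile("SELECT `a` FROM t")
['SELECT `a` FROM t']
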
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
480def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
481    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
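
Helpers like this are typically registered in a dialect's Generator.TRANSFORMS
map; a hypothetical sketch (the MyDialect subclass is illustrative, not part of
the library):

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.duckdb import DuckDB

class MyDialect(DuckDB):  # hypothetical dialect, for illustration only
    class Generator(DuckDB.Generator):
        TRANSFORMS = {
            **DuckDB.Generator.TRANSFORMS,
            # render exp.ApproxDistinct under an explicit function name
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
        }

print(MyDialect().generate(parse_one("SELECT APPROX_DISTINCT(a) FROM t")))
# SELECT APPROX_COUNT_DISTINCT(a) FROM t
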
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
484def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
485    if expression.args.get("accuracy"):
486        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
487    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
490def if_sql(
491    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
492) -> t.Callable[[Generator, exp.If], str]:
493    def _if_sql(self: Generator, expression: exp.If) -> str:
494        return self.func(
495            name,
496            expression.this,
497            expression.args.get("true"),
498            expression.args.get("false") or false_value,
499        )
500
501    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
504def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
505    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
508def arrow_json_extract_scalar_sql(
509    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
510) -> str:
511    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
514def inline_array_sql(self: Generator, expression: exp.Array) -> str:
515    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
518def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
519    return self.like_sql(
520        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
521    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
524def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
525    zone = self.sql(expression, "this")
526    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
529def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
530    if expression.args.get("recursive"):
531        self.unsupported("Recursive CTEs are unsupported")
532        expression.args["recursive"] = False
533    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
536def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
537    n = self.sql(expression, "this")
538    d = self.sql(expression, "expression")
539    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
542def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
543    self.unsupported("TABLESAMPLE unsupported")
544    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
547def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
548    self.unsupported("PIVOT unsupported")
549    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
552def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
553    return self.cast_sql(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
556def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
557    self.unsupported("Properties unsupported")
558    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
561def no_comment_column_constraint_sql(
562    self: Generator, expression: exp.CommentColumnConstraint
563) -> str:
564    self.unsupported("CommentColumnConstraint unsupported")
565    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
568def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
569    self.unsupported("MAP_FROM_ENTRIES unsupported")
570    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
573def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
574    this = self.sql(expression, "this")
575    substr = self.sql(expression, "substr")
576    position = self.sql(expression, "position")
577    if position:
578        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
579    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
582def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
583    return (
584        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
585    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
588def var_map_sql(
589    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
590) -> str:
591    keys = expression.args["keys"]
592    values = expression.args["values"]
593
594    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
595        self.unsupported("Cannot convert array columns into map.")
596        return self.func(map_func_name, keys, values)
597
598    args = []
599    for key, value in zip(keys.expressions, values.expressions):
600        args.append(self.sql(key))
601        args.append(self.sql(value))
602
603    return self.func(map_func_name, *args)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
606def format_time_lambda(
607    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
608) -> t.Callable[[t.List], E]:
609    """Helper used for time expressions.
610
611    Args:
612        exp_class: the expression class to instantiate.
613        dialect: target sql dialect.
614        default: the default format, True being time.
615
616    Returns:
617        A callable that can be used to return the appropriately formatted time expression.
618    """
619
620    def _format_time(args: t.List):
621        return exp_class(
622            this=seq_get(args, 0),
623            format=Dialect[dialect].format_time(
624                seq_get(args, 1)
625                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
626            ),
627        )
628
629    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.
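
A sketch of what the returned callable does, using Hive's TIME_MAPPING to
translate yyyy-MM-dd:

>>> from sqlglot import exp
>>> from sqlglot.dialects.dialect import format_time_lambda
>>> str_to_time = format_time_lambda(exp.StrToTime, "hive", default=True)
>>> node = str_to_time([exp.Literal.string("2020-01-01"), exp.Literal.string("yyyy-MM-dd")])
>>> node.args["format"].sql()
"'%Y-%m-%d'"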

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
632def time_format(
633    dialect: DialectType = None,
634) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
635    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
636        """
637        Returns the time format for a given expression, unless it's equivalent
638        to the default time format of the dialect of interest.
639        """
640        time_format = self.format_time(expression)
641        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
642
643    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
646def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
647    """
648    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
649    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
650    columns are removed from the create statement.
651    """
652    has_schema = isinstance(expression.this, exp.Schema)
653    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
654
655    if has_schema and is_partitionable:
656        prop = expression.find(exp.PartitionedByProperty)
657        if prop and prop.this and not isinstance(prop.this, exp.Schema):
658            schema = expression.this
659            columns = {v.name.upper() for v in prop.this.expressions}
660            partitions = [col for col in schema.expressions if col.name.upper() in columns]
661            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
662            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
663            expression.set("this", schema)
664
665    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
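
A minimal sketch of the transformation, building the pre-transform AST by hand
(the exp.Tuple of column names stands in for the "array of column names" case
and is purely illustrative):

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import Dialect, create_with_partitions_sql

create = parse_one("CREATE TABLE t (a INT, ds STRING)")
create.set(
    "properties",
    exp.Properties(
        expressions=[
            exp.PartitionedByProperty(this=exp.Tuple(expressions=[exp.column("ds")]))
        ]
    ),
)

gen = Dialect.get_or_raise("hive").generator()
# ds is removed from the column list and typed under PARTITIONED BY, roughly:
# CREATE TABLE t (a INT) PARTITIONED BY (ds STRING)
print(create_with_partitions_sql(gen, create))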

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
668def parse_date_delta(
669    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
670) -> t.Callable[[t.List], E]:
671    def inner_func(args: t.List) -> E:
672        unit_based = len(args) == 3
673        this = args[2] if unit_based else seq_get(args, 0)
674        unit = args[0] if unit_based else exp.Literal.string("DAY")
675        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
676        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
677
678    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
681def parse_date_delta_with_interval(
682    expression_class: t.Type[E],
683) -> t.Callable[[t.List], t.Optional[E]]:
684    def func(args: t.List) -> t.Optional[E]:
685        if len(args) < 2:
686            return None
687
688        interval = args[1]
689
690        if not isinstance(interval, exp.Interval):
691            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
692
693        expression = interval.this
694        if expression and expression.is_string:
695            expression = exp.Literal.number(expression.this)
696
697        return expression_class(
698            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
699        )
700
701    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
704def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
705    unit = seq_get(args, 0)
706    this = seq_get(args, 1)
707
708    if isinstance(this, exp.Cast) and this.is_type("date"):
709        return exp.DateTrunc(unit=unit, this=this)
710    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
713def date_add_interval_sql(
714    data_type: str, kind: str
715) -> t.Callable[[Generator, exp.Expression], str]:
716    def func(self: Generator, expression: exp.Expression) -> str:
717        this = self.sql(expression, "this")
718        unit = expression.args.get("unit")
719        unit = exp.var(unit.name.upper() if unit else "DAY")
720        interval = exp.Interval(this=expression.expression, unit=unit)
721        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
722
723    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
726def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
727    return self.func(
728        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
729    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
732def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
733    if not expression.expression:
734        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
735    if expression.text("expression").lower() in TIMEZONES:
736        return self.sql(
737            exp.AtTimeZone(
738                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
739                zone=expression.expression,
740            )
741        )
742    return self.function_fallback_sql(expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
745def locate_to_strposition(args: t.List) -> exp.Expression:
746    return exp.StrPosition(
747        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
748    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
751def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
752    return self.func(
753        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
754    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
757def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
758    return self.sql(
759        exp.Substring(
760            this=expression.this, start=exp.Literal.number(1), length=expression.expression
761        )
762    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
765def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
766    return self.sql(
767        exp.Substring(
768            this=expression.this,
769            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
770        )
771    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
774def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
775    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
778def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
779    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
783def encode_decode_sql(
784    self: Generator, expression: exp.Expression, name: str, replace: bool = True
785) -> str:
786    charset = expression.args.get("charset")
787    if charset and charset.name.lower() != "utf-8":
788        self.unsupported(f"Expected utf-8 character set, got {charset}.")
789
790    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
793def min_or_least(self: Generator, expression: exp.Min) -> str:
794    name = "LEAST" if expression.expressions else "MIN"
795    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
798def max_or_greatest(self: Generator, expression: exp.Max) -> str:
799    name = "GREATEST" if expression.expressions else "MAX"
800    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
803def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
804    cond = expression.this
805
806    if isinstance(expression.this, exp.Distinct):
807        cond = expression.this.expressions[0]
808        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
809
810    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
813def trim_sql(self: Generator, expression: exp.Trim) -> str:
814    target = self.sql(expression, "this")
815    trim_type = self.sql(expression, "position")
816    remove_chars = self.sql(expression, "expression")
817    collation = self.sql(expression, "collation")
818
819    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
820    if not remove_chars and not collation:
821        return self.trim_sql(expression)
822
823    trim_type = f"{trim_type} " if trim_type else ""
824    remove_chars = f"{remove_chars} " if remove_chars else ""
825    from_part = "FROM " if trim_type or remove_chars else ""
826    collation = f" COLLATE {collation}" if collation else ""
827    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
830def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
831    return self.func("STRPTIME", expression.this, self.format_time(expression))
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
834def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
835    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
838def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
839    delim, *rest_args = expression.expressions
840    return self.sql(
841        reduce(
842            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
843            rest_args,
844        )
845    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
848def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
849    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
850    if bad_args:
851        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
852
853    return self.func(
854        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
855    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
858def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
859    bad_args = list(
860        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
861    )
862    if bad_args:
863        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
864
865    return self.func(
866        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
867    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
870def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
871    names = []
872    for agg in aggregations:
873        if isinstance(agg, exp.Alias):
874            names.append(agg.alias)
875        else:
876            """
877            This case corresponds to aggregations without aliases being used as suffixes
878            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
879            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
880            Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quotes).
881            """
882            agg_all_unquoted = agg.transform(
883                lambda node: (
884                    exp.Identifier(this=node.name, quoted=False)
885                    if isinstance(node, exp.Identifier)
886                    else node
887                )
888            )
889            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
890
891    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
894def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
895    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def parse_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
899def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
900    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
903def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
904    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
907def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
908    a = self.sql(expression.left)
909    b = self.sql(expression.right)
910    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
913def is_parse_json(expression: exp.Expression) -> bool:
914    return isinstance(expression, exp.ParseJSON) or (
915        isinstance(expression, exp.Cast) and expression.is_type("json")
916    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
919def isnull_to_is_null(args: t.List) -> exp.Expression:
920    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
923def generatedasidentitycolumnconstraint_sql(
924    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
925) -> str:
926    start = self.sql(expression, "start") or "1"
927    increment = self.sql(expression, "increment") or "1"
928    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
931def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
932    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
933        if expression.args.get("count"):
934            self.unsupported(f"Only two arguments are supported in function {name}.")
935
936        return self.func(name, expression.this, expression.expression)
937
938    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
941def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
942    this = expression.this.copy()
943
944    return_type = expression.return_type
945    if return_type.is_type(exp.DataType.Type.DATE):
946        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
947        # can truncate timestamp strings, because some dialects can't cast them to DATE
948        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
949
950    expression.this.replace(exp.cast(this, return_type))
951    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
954def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
955    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
956        if cast and isinstance(expression, exp.TsOrDsAdd):
957            expression = ts_or_ds_add_cast(expression)
958
959        return self.func(
960            name,
961            exp.var(expression.text("unit").upper() or "DAY"),
962            expression.expression,
963            expression.this,
964        )
965
966    return _delta_sql
def prepend_dollar_to_path(expression: sqlglot.expressions.GetPath) -> sqlglot.expressions.GetPath:
969def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
970    from sqlglot.optimizer.simplify import simplify
971
972    # Makes sure the path will be evaluated correctly at runtime to include the path root.
973    # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
974    path = expression.expression
975    path = exp.func(
976        "if",
977        exp.func("startswith", path, "'['"),
978        exp.func("concat", "'$'", path),
979        exp.func("concat", "'$.'", path),
980    )
981
982    expression.expression.replace(simplify(path))
983    return expression
def path_to_jsonpath( name: str = 'JSON_EXTRACT') -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.GetPath], str]:
986def path_to_jsonpath(
987    name: str = "JSON_EXTRACT",
988) -> t.Callable[[Generator, exp.GetPath], str]:
989    def _transform(self: Generator, expression: exp.GetPath) -> str:
990        return rename_func(name)(self, prepend_dollar_to_path(expression))
991
992    return _transform
def no_last_day_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.LastDay) -> str:
 995def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
 996    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
 997    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
 998    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
 999
1000    return self.sql(exp.cast(minus_one_day, "date"))
def merge_without_target_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Merge) -> str:
1003def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
1004    """Remove table refs from columns in when statements."""
1005    alias = expression.this.args.get("alias")
1006
1007    normalize = lambda identifier: (
1008        self.dialect.normalize_identifier(identifier).name if identifier else None
1009    )
1010
1011    targets = {normalize(expression.this.this)}
1012
1013    if alias:
1014        targets.add(normalize(alias.this))
1015
1016    for when in expression.expressions:
1017        when.transform(
1018            lambda node: (
1019                exp.column(node.this)
1020                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1021                else node
1022            ),
1023            copy=False,
1024        )
1025
1026    return self.merge_sql(expression)

Remove table refs from columns in when statements.
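
A sketch of the effect, calling the helper directly with a generator (the choice
of dialect is illustrative):

import sqlglot
from sqlglot.dialects.dialect import Dialect, merge_without_target_sql

merge = sqlglot.parse_one(
    "MERGE INTO t AS tgt USING s ON tgt.id = s.id "
    "WHEN MATCHED THEN UPDATE SET tgt.v = s.v"
)

gen = Dialect.get_or_raise("postgres").generator()
# target refs inside the WHEN clause are stripped, so `tgt.v = s.v` becomes
# `v = s.v`, while the ON condition is left untouched
print(merge_without_target_sql(gen, merge))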