Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import TIMEZONES, format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
# Type variable bound to binary expressions, used by `binary_from_function`.
B = t.TypeVar("B", bound=exp.Binary)

# Unions of the date arithmetic nodes shared by the generic delta helpers below.
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
 21
 22
class Dialects(str, Enum):
    """Names of all SQL dialects known to sqlglot.

    Inheriting from ``str`` lets members compare equal to their plain-string
    names, so e.g. ``Dialects.MYSQL == "mysql"`` holds.
    """

    # The empty string maps to the base `Dialect` class itself.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    TABLEAU = "tableau"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # NOTE(review): mixed-case member name is inconsistent with the rest, but
    # renaming it would break external references to `Dialects.Doris`.
    Doris = "doris"
 47
 48
class _Dialect(type):
    """Metaclass for `Dialect`.

    Registers every subclass in a global name -> class registry and, at class
    creation time, wires dialect-level settings (time mappings, quoting
    delimiters, feature flags) into the dialect's tokenizer, parser and
    generator classes.
    """

    # Registry of all known dialects, keyed by their registered (lower-case) name.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registered string
        # name, and to instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        # Register under the matching Dialects enum value when one exists,
        # otherwise fall back to the lower-cased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute tries over the time-format mappings for fast
        # longest-prefix matching in `format_time`.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first registered quote/identifier pair is treated as canonical.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Look up the (start, end) delimiters of a typed string literal
            # (bit/hex/byte), or (None, None) when the dialect has none.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Collect the dialect's non-callable, non-dunder attributes so they can
        # be mirrored onto the tokenizer/parser/generator classes below.
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # NOTE(review): `enum` is a Dialects member (str-valued) or None, so
        # this disables SELECT_KINDS for every dialect except the base dialect
        # and BigQuery — presumably only those support SELECT AS STRUCT/VALUE;
        # confirm against the generator.
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # Dialects without SEMI/ANTI joins may use those words as table aliases.
        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        klass.generator_class.can_identify = klass.can_identify

        return klass
139
140
class Dialect(metaclass=_Dialect):
    """Base class describing a SQL dialect.

    Class-level flags configure parsing/generation behavior; the `_Dialect`
    metaclass propagates them to the tokenizer, parser and generator classes
    and fills in the "Autofilled" attributes below.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    # Default date/time literal formats (quoted, ready for SQL output).
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled (overwritten by the _Dialect metaclass at class creation)
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances delegate equality to their class; combined with the
        # metaclass' __eq__ this lets instances compare equal to names too.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a name/instance/class into a Dialect class, raising on unknown names."""
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time-format string/literal into the python format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # Characters in the "non-normalized" case make the identifier case sensitive.
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an identifier as quoted when requested, case sensitive, or unsafe."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql` into the given expression type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Render an expression tree as SQL in this dialect."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and re-generate each statement in this dialect."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiated and cached per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
350
351
352DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
353
354
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator transform that renders the expression as a call to `name`,
    forwarding all of the expression's args as positional arguments."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
357
358
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render ApproxDistinct as APPROX_COUNT_DISTINCT, warning when an
    accuracy argument is present (it cannot be expressed)."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
363
364
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a transform rendering exp.If as `name(cond, true, false)`, with
    `false_value` substituted when no false branch is present."""

    def _if_sql(self: Generator, expression: exp.If) -> str:
        condition = expression.this
        true_branch = expression.args.get("true")
        false_branch = expression.args.get("false") or false_value
        return self.func(name, condition, true_branch, false_branch)

    return _if_sql
377
378
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the arrow operator: `lhs -> rhs`."""
    return self.binary(expression, "->")
381
382
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the double-arrow operator: `lhs ->> rhs`."""
    return self.binary(expression, "->>")
387
388
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal with bracket syntax: `[a, b, c]`."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
391
392
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE by lower-casing the left operand and using plain LIKE."""
    lowered = exp.Lower(this=expression.this)
    return self.like_sql(exp.Like(this=lowered, expression=expression.expression))
397
398
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE when set."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
402
403
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Strip the RECURSIVE flag (with a warning) for dialects that lack it,
    then render the WITH clause normally."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False

    return self.with_sql(expression)
409
410
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF guard yielding NULL on a zero divisor."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
415
416
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Warn and drop TABLESAMPLE, rendering only the sampled relation."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
420
421
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Warn that PIVOT cannot be expressed and render nothing."""
    self.unsupported("PIVOT unsupported")
    return ""
425
426
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
429
430
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Warn that table properties cannot be expressed and render nothing."""
    self.unsupported("Properties unsupported")
    return ""
434
435
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Warn that column COMMENT constraints cannot be expressed; render nothing."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
441
442
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Warn that MAP_FROM_ENTRIES cannot be expressed and render nothing."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
446
447
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating the optional start position by
    searching a SUBSTR suffix and shifting the result back."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"
    # Search from `start`, then translate back to the original 1-based index.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
455
456
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access with dot notation, quoting the field name
    through the identifier machinery."""
    struct = self.sql(expression, "this")
    field = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct}.{field}"
461
462
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as `map_func_name(k1, v1, k2, v2, ...)`.

    When the keys/values are not array literals they cannot be interleaved,
    so the raw arguments are forwarded as-is with a warning.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
479
480
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        fmt = seq_get(args, 1)
        if not fmt:
            # Fall back to the dialect's TIME_FORMAT or the provided default.
            fmt = Dialect[dialect].TIME_FORMAT if default is True else default or None

        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _format_time
505
506
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a transform returning an expression's time format, unless it equals
    the default TIME_FORMAT of the dialect of interest (then None)."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return fmt

    return _time_format
519
520
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        # Only rewrite when PARTITIONED BY holds bare column names (not already a schema).
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Match partition columns against schema columns case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Move the partition columns out of the schema and into the property.
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
541
542
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date add/diff functions, handling both the two-arg
    form and the three-arg form with a leading unit, optionally remapping units."""

    def inner_func(args: t.List) -> E:
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
554
555
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for DATE_ADD-style functions whose second argument must
    be an INTERVAL expression."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # Normalize quoted interval amounts (e.g. '5') into numeric literals.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
577
578
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value): DateTrunc when the value is a DATE cast,
    TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    value = seq_get(args, 1)

    if isinstance(value, exp.Cast) and value.is_type("date"):
        return exp.DateTrunc(unit=unit, this=value)
    return exp.TimestampTrunc(this=value, unit=unit)
586
587
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a transform rendering `<data_type>_<kind>(this, INTERVAL n unit)`,
    e.g. DATE_ADD / TIMESTAMP_SUB, defaulting the unit to DAY."""

    def func(self: Generator, expression: exp.Expression) -> str:
        target = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit_var = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit_var)
        return f"{data_type}_{kind}({target}, {self.sql(interval)})"

    return func
599
600
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
605
606
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Rewrite TIMESTAMP(...) for dialects without the function: a bare value
    becomes a CAST, a known time zone becomes CAST ... AT TIME ZONE, anything
    else falls back to generic function rendering."""
    zone = expression.expression
    if not zone:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))

    if expression.text("expression").lower() in TIMEZONES:
        casted = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
        return self.sql(exp.AtTimeZone(this=casted, zone=zone))

    return self.function_fallback_sql(expression)
618
619
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, str[, pos]) arguments onto a StrPosition node
    (note the swapped first two arguments)."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )
624
625
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, str[, pos])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
630
631
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite LEFT(s, n) as SUBSTRING(s, 1, n)."""
    substring = exp.Substring(
        this=expression.this, start=exp.Literal.number(1), length=expression.expression
    )
    return self.sql(substring)
638
639
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite RIGHT(s, n) as SUBSTRING(s, LENGTH(s) - (n - 1)).

    Fix: the parameter was annotated `exp.Left`, but this transform is mapped
    to RIGHT and receives `exp.Right` nodes (`this` = string, `expression` = n).
    """
    return self.sql(
        exp.Substring(
            this=expression.this,
            # LENGTH(s) - (n - 1) is the 1-based start of the last n characters.
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
647
648
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a CAST to TIMESTAMP."""
    casted = exp.cast(expression.this, "timestamp")
    return self.sql(casted)
651
652
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a CAST to DATE."""
    casted = exp.cast(expression.this, "date")
    return self.sql(casted)
655
656
# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render ENCODE/DECODE for engines whose variants take no charset argument
    and assume utf-8; warn on any other charset."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
666
667
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN as LEAST when it has extra (variadic) arguments."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
671
672
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX as GREATEST when it has extra (variadic) arguments."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
676
677
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT is dropped with
    a warning since it cannot be expressed in this form."""
    cond = expression.this
    if isinstance(cond, exp.Distinct):
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
686
687
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM using the `TRIM([position] [chars] FROM target [COLLATE c])`
    syntax when trim characters or a collation are present; otherwise defer to
    the generator's built-in TRIM/LTRIM/RTRIM rendering."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
703
704
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
707
708
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a transform casting TS_OR_DS_TO_DATE to DATE, routing through
    STR_TO_TIME first when a non-default format is attached."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)

        if fmt and fmt not in (target_dialect.TIME_FORMAT, target_dialect.DATE_FORMAT):
            parsed = exp.StrToTime(this=expression.this, format=expression.args["format"])
            return self.sql(exp.cast(parsed, "date"))

        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql
723
724
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Render CONCAT(a, b, c) with the string-concatenation operator: a || b || c."""
    dpipe = reduce(lambda acc, operand: exp.DPipe(this=acc, expression=operand), expression.expressions)
    return self.sql(dpipe)
727
728
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render CONCAT_WS(sep, a, b, ...) as a || sep || b || sep || ..."""
    delim, *operands = expression.expressions
    joined = reduce(
        lambda acc, operand: exp.DPipe(this=acc, expression=exp.DPipe(this=delim, expression=operand)),
        operands,
    )
    return self.sql(joined)
737
738
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT(value, pattern[, group]), warning about optional
    args the target dialect cannot express."""
    bad_args = [arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
747
748
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE(value, pattern, replacement), warning about
    optional args the target dialect cannot express."""
    bad_args = [
        arg
        for arg in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
759
760
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Derive the output column name suffixes for a list of pivot aggregations."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as suffixes (e.g. col_avg(foo)).
        # Unquote identifiers here because they are quoted again in the base
        # parser's `_parse_pivot` via `to_identifier`; otherwise we would end
        # up with doubly-quoted names like col_avg(`foo`).
        unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
781
782
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser mapping a function's first two args onto a binary expression."""

    def _builder(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _builder
785
786
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value) into a TimestampTrunc node."""
    unit, value = seq_get(args, 0), seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
790
791
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE as MAX for dialects that lack it."""
    return self.func("MAX", expression.this)
794
795
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Expand boolean XOR into AND/OR/NOT for dialects without an XOR operator."""
    left = self.sql(expression.left)
    right = self.sql(expression.right)
    return f"({left} AND (NOT {right})) OR ((NOT {left}) AND {right})"
800
801
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair separated by a comma instead of a colon."""
    key = self.sql(expression, "this")
    value = self.sql(expression, "expression")
    return f"{key}, {value}"
805
806
def is_parse_json(expression: exp.Expression) -> bool:
    """Return True if the expression parses JSON: a ParseJSON node or a CAST to JSON."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
811
812
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Parse ISNULL(x) into the parenthesized predicate (x IS NULL)."""
    predicate = exp.Is(this=seq_get(args, 0), expression=exp.null())
    return exp.Paren(this=predicate)
815
816
def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render a generated-as-identity constraint as IDENTITY(start, increment),
    defaulting both values to 1."""
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"
823
824
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a transform rendering ARG_MAX/ARG_MIN as a two-argument call,
    warning when a count argument is present."""

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
833
834
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the TsOrDsAdd operand in a CAST to the expression's return type,
    mutating the expression in place and returning it."""
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression
846
847
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a transform rendering date add/diff as `name(unit, amount, value)`,
    optionally casting TsOrDsAdd operands first (see `ts_or_ds_add_cast`)."""

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        node = expression
        if cast and isinstance(node, exp.TsOrDsAdd):
            node = ts_or_ds_add_cast(node)

        unit = exp.var(node.text("unit") or "day")
        return self.func(name, unit, node.expression, node.this)

    return _delta_sql
class Dialects(builtins.str, enum.Enum):
24class Dialects(str, Enum):
25    DIALECT = ""
26
27    BIGQUERY = "bigquery"
28    CLICKHOUSE = "clickhouse"
29    DATABRICKS = "databricks"
30    DRILL = "drill"
31    DUCKDB = "duckdb"
32    HIVE = "hive"
33    MYSQL = "mysql"
34    ORACLE = "oracle"
35    POSTGRES = "postgres"
36    PRESTO = "presto"
37    REDSHIFT = "redshift"
38    SNOWFLAKE = "snowflake"
39    SPARK = "spark"
40    SPARK2 = "spark2"
41    SQLITE = "sqlite"
42    STARROCKS = "starrocks"
43    TABLEAU = "tableau"
44    TERADATA = "teradata"
45    TRINO = "trino"
46    TSQL = "tsql"
47    Doris = "doris"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    # Base index offset for arrays
    INDEX_OFFSET = 0

    # If true, unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias comes after TABLESAMPLE
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers are resolved as uppercase.
    # None means the dialect treats all identifiers as case-insensitive.
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Whether CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Whether user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Whether SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Default null ordering method to use if not explicitly set.
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings: keys are dialect time formats,
    # values are the corresponding python time formats
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect;
    # such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal to their own class; _Dialect.__eq__ handles
        # string and class comparisons on the type side.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a name / instance / class to a Dialect class, raising on unknown names."""
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if result:
            return result

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Map a dialect time format (string or string literal) to a python-format literal."""
        if isinstance(expression, str):
            # the time formats are quoted, so strip the surrounding quote characters
            return exp.Literal.string(
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            name = expression.this
            normalized = (
                name.upper() if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
            )
            expression.set("this", normalized)

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # The "unsafe" case is whichever one the engine would not fold to.
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True
        if identify == "safe":
            return not cls.case_sensitive(text)
        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an identifier as quoted when forced, case-sensitive, or not "safe"."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            should_quote = (
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
            )
            expression.set("quoted", should_quote)

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse SQL into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse SQL into the requested expression type(s)."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Generate SQL for an expression tree with this dialect's generator."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse SQL and regenerate each statement in this dialect."""
        return [
            self.generate(parsed, copy=False, **opts) if parsed else ""
            for parsed in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily create and cache one tokenizer instance per dialect instance.
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = True
SUPPORTS_SEMI_ANTI_JOIN = True
NORMALIZE_FUNCTIONS: bool | str = 'upper'
LOG_BASE_FIRST = True
NULL_ORDERING = 'nulls_are_small'
TYPED_DIVISION = False
SAFE_DIVISION = False
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
ESCAPE_SEQUENCES: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Type[Dialect]:
230    @classmethod
231    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
232        if not dialect:
233            return cls
234        if isinstance(dialect, _Dialect):
235            return dialect
236        if isinstance(dialect, Dialect):
237            return dialect.__class__
238
239        result = cls.get(dialect)
240        if not result:
241            raise ValueError(f"Unknown dialect '{dialect}'")
242
243        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
245    @classmethod
246    def format_time(
247        cls, expression: t.Optional[str | exp.Expression]
248    ) -> t.Optional[exp.Expression]:
249        if isinstance(expression, str):
250            return exp.Literal.string(
251                # the time formats are quoted
252                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
253            )
254
255        if expression and expression.is_string:
256            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
257
258        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
260    @classmethod
261    def normalize_identifier(cls, expression: E) -> E:
262        """
263        Normalizes an unquoted identifier to either lower or upper case, thus essentially
264        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
265        they will be normalized to lowercase regardless of being quoted or not.
266        """
267        if isinstance(expression, exp.Identifier) and (
268            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
269        ):
270            expression.set(
271                "this",
272                expression.this.upper()
273                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
274                else expression.this.lower(),
275            )
276
277        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized to lowercase regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
279    @classmethod
280    def case_sensitive(cls, text: str) -> bool:
281        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
282        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
283            return False
284
285        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
286        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
288    @classmethod
289    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
290        """Checks if text can be identified given an identify option.
291
292        Args:
293            text: The text to check.
294            identify:
295                "always" or `True`: Always returns true.
296                "safe": True if the identifier is case-insensitive.
297
298        Returns:
299            Whether or not the given text can be identified.
300        """
301        if identify is True or identify == "always":
302            return True
303
304        if identify == "safe":
305            return not cls.case_sensitive(text)
306
307        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
309    @classmethod
310    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
311        if isinstance(expression, exp.Identifier):
312            name = expression.this
313            expression.set(
314                "quoted",
315                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
316            )
317
318        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
320    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
321        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
323    def parse_into(
324        self, expression_type: exp.IntoType, sql: str, **opts
325    ) -> t.List[t.Optional[exp.Expression]]:
326        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
328    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
329        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
331    def transpile(self, sql: str, **opts) -> t.List[str]:
332        return [
333            self.generate(expression, copy=False, **opts) if expression else ""
334            for expression in self.parse(sql)
335        ]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
337    def tokenize(self, sql: str) -> t.List[Token]:
338        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
346    def parser(self, **opts) -> Parser:
347        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
349    def generator(self, **opts) -> Generator:
350        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator that renders the expression as `NAME(<all args, flattened>)`."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render ApproxDistinct as APPROX_COUNT_DISTINCT, warning when accuracy is given."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a generator rendering `NAME(cond, true, false)`, substituting
    `false_value` when the expression has no false branch."""

    def _if_sql(self: Generator, expression: exp.If) -> str:
        condition = expression.this
        true_branch = expression.args.get("true")
        false_branch = expression.args.get("false") or false_value
        return self.func(name, condition, true_branch, false_branch)

    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON(B) extraction with the arrow operator, e.g. `col -> path`."""
    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON(B) extraction with the double-arrow operator, e.g. `col ->> path`."""
    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an ARRAY literal with bracket syntax, e.g. `[1, 2, 3]`."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE with LIKE by lowercasing the left-hand side."""
    lowered = exp.Lower(this=expression.this)
    return self.like_sql(exp.Like(this=lowered, expression=expression.expression))
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, honoring an optional time zone."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Drop the RECURSIVE flag from WITH clauses, warning that it is unsupported."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False

    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF guard yielding NULL on division by zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Strip TABLESAMPLE, rendering only the sampled table, with a warning."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop PIVOT entirely, with a warning."""
    self.unsupported("PIVOT unsupported")
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST."""
    return self.cast_sql(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop table properties entirely, with a warning."""
    self.unsupported("Properties unsupported")
    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop COMMENT column constraints entirely, with a warning."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Drop MAP_FROM_ENTRIES entirely, with a warning."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition with STRPOS, offsetting manually when a start position is given."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    position = self.sql(expression, "position")

    if position:
        # STRPOS has no start-offset argument: search within a substring and
        # shift the result back into the original string's coordinates.
        return f"STRPOS(SUBSTR({haystack}, {position}), {needle}) + {position} - 1"

    return f"STRPOS({haystack}, {needle})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access using dot notation, e.g. `my_struct.field`."""
    struct_sql = self.sql(expression, "this")
    field_sql = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct_sql}.{field_sql}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a MAP constructor with interleaved key/value arguments.

    Falls back to `MAP(keys, values)` with an unsupported warning when the
    keys and values are not array literals.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = []
    for key, value in zip(keys.expressions, values.expressions):
        interleaved.extend((self.sql(key), self.sql(value)))

    return self.func(map_func_name, *interleaved)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        target_dialect = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            # Fall back to the dialect's TIME_FORMAT when default is True,
            # otherwise to the explicit default (or None).
            fmt = target_dialect.TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=target_dialect.format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a generator that returns an expression's time format, or None when
    it equals the default time format of the dialect of interest."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        fmt = self.format_time(expression)
        if fmt == Dialect.get_or_raise(dialect).TIME_FORMAT:
            return None
        return fmt

    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
            remaining = [e for e in schema.expressions if e not in partitions]
            # Move the partition columns out of the schema and into the property.
            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date add/diff functions, normalizing the unit argument.

    With three args the unit comes first; with fewer, the unit defaults to DAY.
    """

    def _parse(args: t.List) -> E:
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic whose second argument must be an INTERVAL."""

    def _parse(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # Normalize string interval amounts such as '5' into numeric literals.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return _parse
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value): DateTrunc for date casts, TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator rendering `<DATA_TYPE>_<KIND>(this, INTERVAL ...)` date arithmetic."""

    def _date_add_interval(self: Generator, expression: exp.Expression) -> str:
        this_sql = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit_var = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit_var)
        return f"{data_type}_{kind}({this_sql}, {self.sql(interval)})"

    return _date_add_interval
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', this), defaulting the unit to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Emulate TIMESTAMP() with casts, using AT TIME ZONE for known time zones."""
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))

    if expression.text("expression").lower() in TIMEZONES:
        casted = exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
        return self.sql(exp.AtTimeZone(this=casted, zone=expression.expression))

    return self.function_fallback_sql(expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, str[, pos]) arguments onto a StrPosition node."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, str[, pos])."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Rewrite ``LEFT(x, n)`` as ``SUBSTRING(x, 1, n)``."""
    substring = exp.Substring(
        this=expression.this,
        start=exp.Literal.number(1),
        length=expression.expression,
    )
    return self.sql(substring)
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Rewrite ``RIGHT(x, n)`` as ``SUBSTRING(x, LENGTH(x) - (n - 1))``.

    Fixes the parameter annotation: this handler receives ``exp.Right``
    nodes (it was previously annotated ``exp.Left``, likely a copy-paste
    slip from ``left_to_substring_sql``). Behavior is unchanged.
    """
    return self.sql(
        exp.Substring(
            this=expression.this,
            # Start n - 1 characters before the end of the string.
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a cast of its operand to TIMESTAMP."""
    casted = exp.cast(expression.this, "timestamp")
    return self.sql(casted)
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a cast of its operand to DATE."""
    casted = exp.cast(expression.this, "date")
    return self.sql(casted)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an ENCODE/DECODE-style function call.

    Warns when a character set other than UTF-8 is requested. The optional
    ``replace`` flag controls whether the node's "replace" arg is forwarded.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN as LEAST when invoked with multiple arguments."""
    func_name = "LEAST" if expression.expressions else "MIN"
    return rename_func(func_name)(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX as GREATEST when invoked with multiple arguments."""
    func_name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(func_name)(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Transpile ``COUNT_IF(cond)`` into ``SUM(IF(cond, 1, 0))``."""
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        # DISTINCT has no SUM/IF equivalent; warn and use the inner expression.
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", condition, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM in the ANSI ``TRIM([position] [chars] FROM target)`` form.

    Delegates to the generator's default TRIM/LTRIM/RTRIM rendering when no
    removal characters or collation are present.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not chars and not collation:
        return self.trim_sql(expression)

    position_part = f"{position} " if position else ""
    chars_part = f"{chars} " if chars else ""
    from_part = "FROM " if position_part or chars_part else ""
    collate_part = f" COLLATE {collation}" if collation else ""
    return f"TRIM({position_part}{chars_part}{from_part}{target}{collate_part})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion via ``STRPTIME``."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator function that renders TsOrDsToDate as a DATE cast.

    When the node carries a format string that is neither the dialect's
    default TIME_FORMAT nor DATE_FORMAT, the value is parsed with StrToTime
    before being cast to DATE.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        default_formats = (target_dialect.TIME_FORMAT, target_dialect.DATE_FORMAT)
        if time_format and time_format not in default_formats:
            parsed = exp.StrToTime(this=expression.this, format=expression.args["format"])
            return self.sql(exp.cast(parsed, "date"))

        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Rewrite ``CONCAT(a, b, ...)`` as the ``a || b || ...`` operator chain."""
    chained = reduce(
        lambda acc, operand: exp.DPipe(this=acc, expression=operand),
        expression.expressions,
    )
    return self.sql(chained)
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Rewrite ``CONCAT_WS(sep, a, b, ...)`` as ``a || sep || b || sep || ...``."""
    separator, *operands = expression.expressions

    def _join(acc: exp.Expression, operand: exp.Expression) -> exp.DPipe:
        # Interleave the separator before each subsequent operand.
        return exp.DPipe(this=acc, expression=exp.DPipe(this=separator, expression=operand))

    return self.sql(reduce(_join, operands))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about args the target dialect lacks."""
    bad_args = [arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT",
        expression.this,
        expression.expression,
        expression.args.get("group"),
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about args the target dialect lacks."""
    bad_args = [
        arg
        for arg in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE",
        expression.this,
        expression.expression,
        expression.args["replacement"],
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column names produced by a list of PIVOT aggregations.

    Aliased aggregations contribute their alias; unaliased ones contribute
    their (unquoted) SQL text. Fix: the explanatory text in the `else` branch
    was a bare triple-quoted string — a no-op expression statement, not a
    comment — and is now a proper comment block.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used
            # as suffixes (e.g. col_avg(foo)). We need to unquote identifiers
            # because they're going to be quoted in the base parser's
            # `_parse_pivot` method, due to `to_identifier`. Otherwise, we'd
            # end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser that maps two positional args onto a binary expression."""

    def _builder(args: t.List) -> B:
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _builder
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build a TimestampTrunc from ``(unit, value)`` positional args."""
    return exp.TimestampTrunc(
        this=seq_get(args, 1),
        unit=seq_get(args, 0),
    )
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Approximate ANY_VALUE with MAX for dialects that lack it."""
    return self.func("MAX", expression.this)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Emulate boolean XOR with AND/OR/NOT for dialects without an XOR operator."""
    left = self.sql(expression.left)
    right = self.sql(expression.right)
    return f"({left} AND (NOT {right})) OR ((NOT {left}) AND {right})"
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render a JSON key/value pair as ``key, value`` (no KEY/VALUE keywords)."""
    key = self.sql(expression, "this")
    value = self.sql(expression, "expression")
    return f"{key}, {value}"
def is_parse_json(expression: exp.Expression) -> bool:
    """Return True for ParseJSON nodes and for casts to the JSON type."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Convert ``ISNULL(x)`` args into the expression tree for ``(x IS NULL)``."""
    is_null = exp.Is(this=seq_get(args, 0), expression=exp.null())
    return exp.Paren(this=is_null)
def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render an identity column constraint as ``IDENTITY(start, increment)``.

    Both start and increment default to 1 when absent.
    """
    step = self.sql(expression, "increment") or "1"
    first = self.sql(expression, "start") or "1"
    return f"IDENTITY({first}, {step})"
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a two-argument renderer for ARG_MAX/ARG_MIN.

    Emits an unsupported warning when a third "count" argument is present,
    since the target dialect only accepts two arguments.
    """

    def _render(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _render
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap a TsOrDsAdd operand in a cast to the node's return type (in place)."""
    operand = expression.this.copy()
    target_type = expression.return_type

    if target_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, target_type))
    return expression
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a renderer for date add/diff functions of the given name.

    When ``cast`` is set, TsOrDsAdd operands are first wrapped in a cast to
    the node's return type via ``ts_or_ds_add_cast``.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        unit = exp.var(expression.text("unit") or "day")
        return self.func(name, unit, expression.expression, expression.this)

    return _delta_sql