Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
# Type variable bound to `exp.Binary`; used by `binary_from_function` to type its result.
B = t.TypeVar("B", bound=exp.Binary)
 18
 19
 20class Dialects(str, Enum):
 21    DIALECT = ""
 22
 23    BIGQUERY = "bigquery"
 24    CLICKHOUSE = "clickhouse"
 25    DATABRICKS = "databricks"
 26    DRILL = "drill"
 27    DUCKDB = "duckdb"
 28    HIVE = "hive"
 29    MYSQL = "mysql"
 30    ORACLE = "oracle"
 31    POSTGRES = "postgres"
 32    PRESTO = "presto"
 33    REDSHIFT = "redshift"
 34    SNOWFLAKE = "snowflake"
 35    SPARK = "spark"
 36    SPARK2 = "spark2"
 37    SQLITE = "sqlite"
 38    STARROCKS = "starrocks"
 39    TABLEAU = "tableau"
 40    TERADATA = "teradata"
 41    TRINO = "trino"
 42    TSQL = "tsql"
 43    Doris = "doris"
 44
 45
 46class _Dialect(type):
 47    classes: t.Dict[str, t.Type[Dialect]] = {}
 48
 49    def __eq__(cls, other: t.Any) -> bool:
 50        if cls is other:
 51            return True
 52        if isinstance(other, str):
 53            return cls is cls.get(other)
 54        if isinstance(other, Dialect):
 55            return cls is type(other)
 56
 57        return False
 58
 59    def __hash__(cls) -> int:
 60        return hash(cls.__name__.lower())
 61
 62    @classmethod
 63    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 64        return cls.classes[key]
 65
 66    @classmethod
 67    def get(
 68        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 69    ) -> t.Optional[t.Type[Dialect]]:
 70        return cls.classes.get(key, default)
 71
 72    def __new__(cls, clsname, bases, attrs):
 73        klass = super().__new__(cls, clsname, bases, attrs)
 74        enum = Dialects.__members__.get(clsname.upper())
 75        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 76
 77        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 78        klass.FORMAT_TRIE = (
 79            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
 80        )
 81        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
 82        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 83
 84        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 85        klass.parser_class = getattr(klass, "Parser", Parser)
 86        klass.generator_class = getattr(klass, "Generator", Generator)
 87
 88        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
 89        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
 90            klass.tokenizer_class._IDENTIFIERS.items()
 91        )[0]
 92
 93        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 94            return next(
 95                (
 96                    (s, e)
 97                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 98                    if t == token_type
 99                ),
100                (None, None),
101            )
102
103        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
104        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
105        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
106
107        dialect_properties = {
108            **{
109                k: v
110                for k, v in vars(klass).items()
111                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
112            },
113            "TOKENIZER_CLASS": klass.tokenizer_class,
114        }
115
116        if enum not in ("", "bigquery"):
117            dialect_properties["SELECT_KINDS"] = ()
118
119        # Pass required dialect properties to the tokenizer, parser and generator classes
120        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
121            for name, value in dialect_properties.items():
122                if hasattr(subclass, name):
123                    setattr(subclass, name, value)
124
125        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
126            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
127
128        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
129            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
130                TokenType.ANTI,
131                TokenType.SEMI,
132            }
133
134        klass.generator_class.can_identify = klass.can_identify
135
136        return klass
137
138
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects: holds dialect settings and ties together the dialect's
    tokenizer, parser and generator classes."""

    # Base index offset for arrays
    INDEX_OFFSET = 0

    # Whether unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers are resolved as uppercase.
    # None means that the dialect treats all identifiers as case-insensitive.
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Whether CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Whether user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Whether SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # How function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Default null ordering method to use if not explicitly set.
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default formats; note that the values include their surrounding single quotes
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled by the metaclass
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # Tries built by the metaclass from TIME_MAPPING / FORMAT_MAPPING
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    # Python-format -> dialect-format mapping and its trie, built by the metaclass
    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        """Delegates to the comparison defined for the instance's class."""
        return type(self) == other

    def __hash__(self) -> int:
        """Hashes an instance like its class."""
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolves `dialect` to a dialect class.

        Args:
            dialect: a dialect name, instance or class; a falsy value resolves to `cls`.

        Returns:
            The matching dialect class.

        Raises:
            ValueError: if `dialect` is a name that isn't registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        resolved = cls.get(dialect)
        if resolved:
            return resolved

        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format from this dialect's notation using `TIME_MAPPING`.

        A plain `str` is expected to include its surrounding quotes, which are stripped.
        String expressions are converted as well; anything else is returned unchanged.
        """
        if isinstance(expression, str):
            # the time formats are quoted
            unquoted = expression[1:-1]
            return exp.Literal.string(format_time(unquoted, cls.TIME_MAPPING, cls.TIME_TRIE))

        if expression and expression.is_string:
            converted = format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)
            return exp.Literal.string(converted)

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.

        The identifier is modified in place and returned; other expressions are returned as-is.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        # Quoted identifiers are left alone unless the dialect is fully case-insensitive.
        if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return expression

        name = expression.this
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            expression.set("this", name.upper())
        else:
            expression.set("this", name.lower())

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        resolves_upper = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if resolves_upper is None:
            return False

        # A character is unsafe when its case differs from the case identifiers resolve to.
        is_unsafe = str.islower if resolves_upper else str.isupper
        for char in text:
            if is_unsafe(char):
                return True
        return False

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.
                Any other value: returns false.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        return identify == "safe" and not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Sets the `quoted` arg of an identifier in place and returns the expression.

        The identifier is marked as quoted when `identify` is truthy, when its name is case
        sensitive for this dialect, or when the name doesn't match `exp.SAFE_IDENTIFIER_RE`.
        Non-identifier expressions are returned unchanged.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        name = expression.this
        needs_quotes = (
            identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes and parses `sql`; `opts` are passed to the parser constructor."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse(tokens, sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenizes `sql` and parses it into `expression_type`; `opts` go to the parser."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generates SQL for `expression`; `opts` are passed to the generator constructor."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parses `sql` and generates one SQL string per parsed statement.

        `opts` are forwarded to `generate` only; `parse` is called without options.
        """
        return [self.generate(statement, **opts) for statement in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenizes `sql` with this instance's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """The tokenizer of this instance, created on first access and then reused."""
        try:
            return self._tokenizer
        except AttributeError:
            self._tokenizer = self.tokenizer_class()
            return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Creates a new parser of this dialect, configured with `opts`."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Creates a new generator of this dialect, configured with `opts`."""
        return self.generator_class(**opts)
331
332
# Anything accepted where a dialect is expected: a dialect name, a `Dialect` instance,
# a `Dialect` class, or None when no dialect is specified.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
334
335
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Returns a generator callback that renders an expression as a call to function `name`.

    The call's arguments are the flattened values of `expression.args`, in their stored order.
    """

    def _renamed(self: Generator, expression: exp.Expression) -> str:
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _renamed
338
339
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Renders `APPROX_COUNT_DISTINCT(<this>)`, reporting an `accuracy` arg as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
344
345
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Returns a generator callback that renders `exp.If` as a function call.

    Args:
        name: name of the function to emit.
        false_value: fallback for the third argument when the expression has no falsy
            "false" arg.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        condition = expression.this
        when_true = expression.args.get("true")
        when_false = expression.args.get("false") or false_value
        return self.func(name, condition, when_true, when_false)

    return _if_sql
358
359
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Renders a JSON extraction as a binary expression using the `->` operator."""
    operator = "->"
    return self.binary(expression, operator)
362
363
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Renders a scalar JSON extraction as a binary expression using the `->>` operator."""
    operator = "->>"
    return self.binary(expression, operator)
368
369
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Renders an array as a bracketed literal, i.e. `[<items>]`."""
    items = self.expressions(expression, flat=True)
    return "[" + items + "]"
372
373
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulates `ILIKE` for dialects without it as `LOWER(<this>) LIKE LOWER(<pattern>)`.

    Both operands are lower-cased: lower-casing only the left side would make any pattern
    that contains upper-case characters (e.g. `x ILIKE 'A%'`) impossible to match.
    The operands are copied, so the given expression is left untouched.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this.copy()),
            expression=exp.Lower(this=expression.expression.copy()),
        )
    )
380
381
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Renders `CURRENT_DATE` without parentheses, adding `AT TIME ZONE` if a zone is set."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
385
386
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Renders a WITH clause for dialects without recursive CTEs.

    A recursive clause is reported as unsupported and rendered as a non-recursive one.
    The flag is cleared on a copy so that generating SQL doesn't alter the caller's tree.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression = expression.copy()
        expression.set("recursive", False)

    return self.with_sql(expression)
392
393
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulates a safe division as `IF(<d> <> 0, <n> / <d>, NULL)`.

    Operands that are binary expressions are wrapped in parentheses; otherwise their
    operators could bind differently once interpolated (e.g. `a + b` divided by `c` would
    render as `a + b / c`). Other operands are rendered exactly as before.
    """

    def operand_sql(arg_key: str) -> str:
        sql = self.sql(expression, arg_key)
        if isinstance(expression.args.get(arg_key), exp.Binary):
            return f"({sql})"
        return sql

    n = operand_sql("this")
    d = operand_sql("expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"
398
399
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Reports TABLESAMPLE as unsupported and renders only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
403
404
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Reports PIVOT as unsupported and renders nothing for it."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
408
409
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Renders a TRY_CAST through the generator's regular `cast_sql`."""
    rendered = self.cast_sql(expression)
    return rendered
412
413
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Reports properties as unsupported and renders nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
417
418
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Reports a column COMMENT constraint as unsupported and renders nothing for it."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
424
425
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Reports MAP_FROM_ENTRIES as unsupported and renders nothing for it."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
429
430
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Renders `exp.StrPosition` with STRPOS.

    Without a start position this is `STRPOS(<this>, <substr>)`. With one, the search runs on
    `SUBSTR(<this>, <position>)` and the result is shifted back by `<position> - 1`.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # NOTE(review): if STRPOS yields 0 for "not found" in the target dialect, this evaluates
    # to <position> - 1 rather than 0 - confirm whether callers rely on the not-found value.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
438
439
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Renders a struct field access with dot notation, i.e. `<struct>.<field>`."""
    struct_sql = self.sql(expression, "this")
    field_sql = self.sql(exp.to_identifier(expression.expression.name))
    return f"{struct_sql}.{field_sql}"
444
445
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Renders a map as `<map_func_name>(k1, v1, k2, v2, ...)`.

    When keys or values aren't array literals they can't be interleaved; this is reported as
    unsupported and they are passed through as the two arguments of the function instead.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the rendered keys and values pairwise.
    interleaved = [
        self.sql(node) for pair in zip(keys.expressions, values.expressions) for node in pair
    ]
    return self.func(map_func_name, *interleaved)
462
463
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        # The dialect is looked up on every call rather than when the helper is created.
        dialect_class = Dialect[dialect]

        time_fmt = seq_get(args, 1)
        if not time_fmt:
            time_fmt = dialect_class.TIME_FORMAT if default is True else default or None

        return exp_class(this=seq_get(args, 0), format=dialect_class.format_time(time_fmt))

    return _format_time
488
489
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Returns a helper that yields an expression's time format, unless it is the default.

    Args:
        dialect: the dialect whose `TIME_FORMAT` counts as the default; resolved on each call.
    """

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt != default_fmt:
            return fmt
        return None

    return _time_format
502
503
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.

    The rewrite is done on a copy, so the given expression isn't modified.
    """
    if not isinstance(expression.this, exp.Schema) or expression.args.get("kind") not in (
        "TABLE",
        "VIEW",
    ):
        return self.create_sql(expression)

    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        # Column names are matched case-insensitively.
        partition_names = {column.name.upper() for column in prop.this.expressions}
        partitions = [
            column_def
            for column_def in schema.expressions
            if column_def.name.upper() in partition_names
        ]
        remaining = [
            column_def for column_def in schema.expressions if column_def not in partitions
        ]
        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)
525
526
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Returns a parser callback for date-delta functions.

    With exactly three arguments they are read as `(unit, expression, this)`; otherwise as
    `(this, expression)` with the unit defaulting to the string literal "DAY".

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping from lower-cased unit names to replacement names.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            # Units missing from the mapping keep their original name.
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
538
539
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Returns a parser callback for date-delta functions of the form `f(this, INTERVAL ...)`.

    The callback returns None when given fewer than two arguments and raises `ParseError`
    when the second argument isn't an interval.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        # A string amount such as '1' is turned into a number literal.
        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return func
561
562
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parses `DATE_TRUNC(unit, this)` arguments.

    Produces an `exp.DateTrunc` when the value is a cast to date and an `exp.TimestampTrunc`
    otherwise.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type("date")
    if not is_date_cast:
        return exp.TimestampTrunc(this=this, unit=unit)

    return exp.DateTrunc(unit=unit, this=this)
570
571
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Returns a generator callback rendering `<data_type>_<kind>(<this>, INTERVAL ...)`.

    The interval's unit is the upper-cased name of the expression's unit, or DAY if unset.
    """
    function_name = f"{data_type}_{kind}"

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")

        unit_arg = expression.args.get("unit")
        unit_name = unit_arg.name.upper() if unit_arg else "DAY"

        interval = exp.Interval(this=expression.expression.copy(), unit=exp.var(unit_name))
        return f"{function_name}({this}, {self.sql(interval)})"

    return func
583
584
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Renders `DATE_TRUNC('<unit>', <this>)`; the unit falls back to 'day' when it is empty."""
    unit_name = expression.text("unit") or "day"
    unit_literal = exp.Literal.string(unit_name)
    return self.func("DATE_TRUNC", unit_literal, expression.this)
589
590
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parses `LOCATE(substr, this[, position])` arguments into an `exp.StrPosition`."""
    substr = seq_get(args, 0)
    haystack = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=haystack, substr=substr, position=position)
595
596
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Renders `exp.StrPosition` as a `LOCATE` call with (substr, this, position) arguments."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
601
602
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Renders `LEFT(x, n)` as a substring of `x` that starts at 1 and has length `n`.

    Works on a copy of the expression, so the original tree is left untouched.
    """
    copied = expression.copy()
    substring = exp.Substring(
        this=copied.this, start=exp.Literal.number(1), length=copied.expression
    )
    return self.sql(substring)
610
611
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Renders `RIGHT(x, n)` as a substring of `x` starting at `LENGTH(x) - (n - 1)`.

    Works on a copy of the expression, so the original tree is left untouched.
    """
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            # No length is given, so the substring runs to the end of the string.
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
620
621
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Renders `exp.TimeStrToTime` as a cast of its operand to timestamp."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
624
625
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Renders `exp.DateStrToDate` as a cast of its operand to date."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
628
629
630# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Renders an encode/decode expression as a call to `name` that takes no charset.

    Args:
        expression: the expression to render; a charset other than utf-8 (compared
            case-insensitively) is reported as unsupported.
        name: name of the function to emit.
        replace: whether to pass the expression's "replace" arg as the second argument.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
639
640
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Renders `exp.Min` as LEAST when it has extra `expressions`, otherwise as MIN."""
    if expression.expressions:
        render = rename_func("LEAST")
    else:
        render = rename_func("MIN")
    return render(self, expression)
644
645
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Renders `exp.Max` as GREATEST when it has extra `expressions`, otherwise as MAX."""
    if expression.expressions:
        render = rename_func("GREATEST")
    else:
        render = rename_func("MAX")
    return render(self, expression)
649
650
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Renders `COUNT_IF(cond)` as `sum(if(cond, 1, 0))`.

    DISTINCT can't be carried over: it is reported as unsupported and only its first
    expression is used as the condition.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition.copy(), 1, 0)
    return self.func("sum", one_or_zero)
659
660
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Renders `TRIM([<position> ][<chars> ][FROM ]<target>[ COLLATE <collation>])`.

    Falls back to the generator's own `trim_sql` when neither characters to remove nor a
    collation are given.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if position:
        pieces.append(f"{position} ")
    if chars:
        pieces.append(f"{chars} ")
    if pieces:
        # FROM is only needed when something precedes the target.
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    return f"TRIM({''.join(pieces)})"
676
677
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Renders a string-to-time conversion as `STRPTIME(<this>, <format>)`."""
    value = expression.this
    time_fmt = self.format_time(expression)
    return self.func("STRPTIME", value, time_fmt)
680
681
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Returns a generator callback that renders `exp.TsOrDsToDate` as a cast to date.

    When the expression carries a format other than the dialect's default time and date
    formats, the value is first parsed with STRPTIME and the result is cast to date.

    Args:
        dialect: name of the dialect whose default formats apply; resolved on each call.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        dialect_class = Dialect.get_or_raise(dialect)
        default_formats = (dialect_class.TIME_FORMAT, dialect_class.DATE_FORMAT)

        fmt = self.format_time(expression)
        if fmt and fmt not in default_formats:
            operand = str_to_time_sql(self, expression)
        else:
            operand = self.sql(expression, "this")

        return self.sql(exp.cast(operand, "date"))

    return _ts_or_ds_to_date_sql
692
693
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Renders a CONCAT as a left-nested chain of `exp.DPipe` nodes over its arguments.

    Works on a copy of the expression, so the original tree is left untouched.
    """

    def _pipe(left: exp.Expression, right: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=left, expression=right)

    operands = expression.copy().expressions
    return self.sql(reduce(_pipe, operands))
697
698
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Renders a CONCAT_WS as a chain of `exp.DPipe` nodes.

    The first argument is the delimiter; it is inserted between each pair of consecutive
    remaining arguments. Works on a copy of the expression.
    """
    delim, *rest_args = expression.copy().expressions

    def _pipe_with_delim(left: exp.Expression, right: exp.Expression) -> exp.DPipe:
        return exp.DPipe(this=left, expression=exp.DPipe(this=delim, expression=right))

    return self.sql(reduce(_pipe_with_delim, rest_args))
708
709
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Renders `REGEXP_EXTRACT(<this>, <pattern>[, <group>])`.

    The args "position", "occurrence" and "parameters" are reported as unsupported if set.
    """
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
718
719
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Renders `REGEXP_REPLACE(<this>, <pattern>[, <replacement>])`.

    The args "position", "occurrence" and "parameters" are reported as unsupported if set.
    A missing "replacement" arg is passed to `self.func` as None, in the same way
    `regexp_extract_sql` passes its optional "group", instead of raising a `KeyError`.
    """
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE",
        expression.this,
        expression.expression,
        expression.args.get("replacement"),
    )
728
729
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Computes the column-name part contributed by each pivot aggregation.

    Aliased aggregations contribute their alias. Unaliased ones contribute their SQL in the
    given dialect, with function names lower-cased and all identifiers unquoted.
    """

    def _unquote(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quotes).
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
750
751
def simplify_literal(expression: E) -> E:
    """Runs the optimizer's `simplify` on the expression's operand unless it is a literal.

    Returns the given expression itself.
    """
    operand = expression.expression
    if isinstance(operand, exp.Literal):
        return expression

    # Imported here rather than at module level (presumably to avoid a circular import).
    from sqlglot.optimizer.simplify import simplify

    # NOTE(review): the result of simplify is discarded, so this relies on simplify
    # modifying the operand in place - confirm.
    simplify(operand)

    return expression
759
760
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Returns a parser callback that builds `expr_type` from a function's first two arguments."""

    def _build(args: t.List) -> B:
        left = seq_get(args, 0)
        right = seq_get(args, 1)
        return expr_type(this=left, expression=right)

    return _build
763
764
765# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Parses `DATE_TRUNC(unit, this)` arguments into an `exp.TimestampTrunc`."""
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
768
769
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Renders `exp.AnyValue` as `MAX(<this>)`."""
    operand = expression.this
    return self.func("MAX", operand)
772
773
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Renders a boolean XOR as `(a AND (NOT b)) OR ((NOT a) AND b)`.

    The result as a whole is not wrapped in parentheses.
    """
    left = self.sql(expression.left)
    right = self.sql(expression.right)

    only_left = f"({left} AND (NOT {right}))"
    only_right = f"((NOT {left}) AND {right})"
    return f"{only_left} OR {only_right}"
778
779
780# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Renders a JSON key/value pair as `<key>, <value>`."""
    key_sql = self.sql(expression, "this")
    value_sql = self.sql(expression, "expression")
    return f"{key_sql}, {value_sql}"
783
784
def is_parse_json(expression: exp.Expression) -> bool:
    """Checks whether the expression is an `exp.ParseJSON` or a cast to the json type."""
    if isinstance(expression, exp.ParseJSON):
        return True

    return isinstance(expression, exp.Cast) and expression.is_type("json")
789
790
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Parses `ISNULL(x)` arguments into the parenthesized predicate `(x IS NULL)`."""
    is_null = exp.Is(this=seq_get(args, 0), expression=exp.null())
    return exp.Paren(this=is_null)
793
794
def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
    """Renders an INSERT, moving a WITH clause of the inserted query onto the INSERT itself.

    The move is performed on a copy, so the given expression isn't modified.
    """
    source_has_cte = expression.expression.args.get("with")
    if source_has_cte:
        expression = expression.copy()
        # Detach the WITH clause from the copied source query and attach it to the INSERT.
        with_clause = expression.expression.args["with"].pop()
        expression.set("with", with_clause)

    return self.insert_sql(expression)
class Dialects(builtins.str, enum.Enum):
class Dialects(str, Enum):
    """String-valued enumeration of the dialect names listed in this module.

    Members compare equal to their string value because the enum also derives
    from ``str``.
    """

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"
    # Upper-case alias so this member can be spelled like every other one.
    # It is declared after `Doris`, so `Doris` remains the canonical member
    # (same name, same iteration order) and existing callers are unaffected.
    DORIS = "doris"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Doris = <Dialects.Doris: 'doris'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
class Dialect(metaclass=_Dialect):
    """Base dialect: class-level settings plus the tokenizer, parser and generator
    classes used by the `tokenize`, `parse` and `generate` methods below."""

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        """Compare this instance's class with ``other`` (the class-level comparison decides)."""
        return type(self) == other

    def __hash__(self) -> int:
        """Hash an instance the same way as its class."""
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve ``dialect`` to a dialect class.

        A falsy value yields ``cls``, a dialect class is returned unchanged and a
        `Dialect` instance yields its class. Anything else is looked up with
        ``cls.get``.

        Raises:
            ValueError: If the lookup finds nothing.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format through this dialect's TIME_MAPPING / TIME_TRIE.

        A plain string has its first and last characters stripped before conversion;
        a string expression is converted from its ``this`` value. Both come back as a
        string literal. Any other input is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the ``quoted`` flag of an identifier; other expressions are returned as is.

        The flag is truthy when ``identify`` is set, when the name is case sensitive
        for this dialect, or when it does not match ``exp.SAFE_IDENTIFIER_RE``.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize ``sql`` and parse it with a parser built from ``opts``."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize ``sql`` and parse it into ``expression_type`` with a parser built from ``opts``."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for ``expression`` with a generator built from ``opts``."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse ``sql`` and generate one string per parsed expression.

        ``opts`` are forwarded to `generate` only; `parse` is called without them.
        """
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize ``sql`` with this dialect's tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        """Tokenizer instance, created on first access and cached as ``_tokenizer``."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Create a new parser of ``parser_class`` configured with ``opts``."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Create a new generator of ``generator_class`` configured with ``opts``."""
        return self.generator_class(**opts)
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
RESOLVES_IDENTIFIERS_AS_UPPERCASE: Optional[bool] = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
DPIPE_IS_STRING_CONCAT = True
STRICT_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = True
SUPPORTS_SEMI_ANTI_JOIN = True
NORMALIZE_FUNCTIONS: bool | str = 'upper'
LOG_BASE_FIRST = True
NULL_ORDERING = 'nulls_are_small'
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}
FORMAT_MAPPING: Dict[str, str] = {}
PSEUDOCOLUMNS: Set[str] = set()
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Type[Dialect]:
214    @classmethod
215    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
216        if not dialect:
217            return cls
218        if isinstance(dialect, _Dialect):
219            return dialect
220        if isinstance(dialect, Dialect):
221            return dialect.__class__
222
223        result = cls.get(dialect)
224        if not result:
225            raise ValueError(f"Unknown dialect '{dialect}'")
226
227        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
229    @classmethod
230    def format_time(
231        cls, expression: t.Optional[str | exp.Expression]
232    ) -> t.Optional[exp.Expression]:
233        if isinstance(expression, str):
234            return exp.Literal.string(
235                # the time formats are quoted
236                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
237            )
238
239        if expression and expression.is_string:
240            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
241
242        return expression
@classmethod
def normalize_identifier(cls, expression: ~E) -> ~E:
244    @classmethod
245    def normalize_identifier(cls, expression: E) -> E:
246        """
247        Normalizes an unquoted identifier to either lower or upper case, thus essentially
248        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
249        they will be normalized to lowercase regardless of being quoted or not.
250        """
251        if isinstance(expression, exp.Identifier) and (
252            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
253        ):
254            expression.set(
255                "this",
256                expression.this.upper()
257                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
258                else expression.this.lower(),
259            )
260
261        return expression

Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized to lowercase regardless of being quoted or not.

@classmethod
def case_sensitive(cls, text: str) -> bool:
263    @classmethod
264    def case_sensitive(cls, text: str) -> bool:
265        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
266        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
267            return False
268
269        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
270        return any(unsafe(char) for char in text)

Checks if text contains any case sensitive characters, based on the dialect's rules.

@classmethod
def can_identify(cls, text: str, identify: str | bool = 'safe') -> bool:
272    @classmethod
273    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
274        """Checks if text can be identified given an identify option.
275
276        Args:
277            text: The text to check.
278            identify:
279                "always" or `True`: Always returns true.
280                "safe": True if the identifier is case-insensitive.
281
282        Returns:
283            Whether or not the given text can be identified.
284        """
285        if identify is True or identify == "always":
286            return True
287
288        if identify == "safe":
289            return not cls.case_sensitive(text)
290
291        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns true. "safe": True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.

@classmethod
def quote_identifier(cls, expression: ~E, identify: bool = True) -> ~E:
293    @classmethod
294    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
295        if isinstance(expression, exp.Identifier):
296            name = expression.this
297            expression.set(
298                "quoted",
299                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
300            )
301
302        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
304    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
305        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
307    def parse_into(
308        self, expression_type: exp.IntoType, sql: str, **opts
309    ) -> t.List[t.Optional[exp.Expression]]:
310        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
312    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
313        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
315    def transpile(self, sql: str, **opts) -> t.List[str]:
316        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
318    def tokenize(self, sql: str) -> t.List[Token]:
319        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
327    def parser(self, **opts) -> Parser:
328        return self.parser_class(**opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
330    def generator(self, **opts) -> Generator:
331        return self.generator_class(**opts)
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START = None
BIT_END = None
HEX_START = None
HEX_END = None
BYTE_START = None
BYTE_END = None
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator callback that emits a call to ``name``.

    The call's arguments are the flattened values of the expression's ``args``.
    """

    def _rename(self: Generator, expression: exp.Expression) -> str:
        arguments = flatten(expression.args.values())
        return self.func(name, *arguments)

    return _rename
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Generate ``APPROX_COUNT_DISTINCT(x)``, reporting an ``accuracy`` arg as unsupported."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Return a generator callback that renders `exp.If` as ``name(cond, true, false)``.

    ``false_value`` is used when the expression has no (truthy) ``false`` arg.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        condition = expression.this
        when_true = expression.args.get("true")
        when_false = expression.args.get("false") or false_value
        return self.func(name, condition, when_true, when_false)

    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Generate a JSON extraction as a binary operation with the ``->`` operator."""
    operator = "->"
    return self.binary(expression, operator)
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Generate a scalar JSON extraction as a binary operation with the ``->>`` operator."""
    operator = "->>"
    return self.binary(expression, operator)
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Generate an array as its flat element list wrapped in square brackets."""
    elements = self.expressions(expression, flat=True)
    return f"[{elements}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ``x ILIKE pattern`` as ``LOWER(x) LIKE LOWER(pattern)``.

    Both operands are lower-cased. Lowering only the subject would make any
    pattern that contains an upper-case character impossible to match, which is
    not what a case-insensitive comparison means.

    Copies of the operands are used so the original tree is not re-parented.
    """
    subject = exp.Lower(this=expression.this.copy())
    pattern = exp.Lower(this=expression.expression.copy())
    return self.like_sql(exp.Like(this=subject, expression=pattern))
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Generate ``CURRENT_DATE`` without parentheses, adding ``AT TIME ZONE`` when a zone is set."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Generate a WITH clause for dialects without recursive CTEs.

    A recursive clause is reported as unsupported and generated without the
    ``recursive`` flag. The flag is cleared on a copy: generating SQL must not
    change the caller's tree, otherwise a later generation of the same tree
    for another dialect would silently lose RECURSIVE.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression = expression.copy()
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate a safe division as ``IF(d <> 0, n / d, NULL)``.

    Operands that are themselves binary expressions are parenthesized. Without
    that, a numerator such as ``a + b`` would render as ``a + b / d`` and be
    evaluated with the wrong precedence. Simple operands render as before.
    """

    def _operand_sql(key: str) -> str:
        sql = self.sql(expression, key)
        return f"({sql})" if isinstance(expression.args.get(key), exp.Binary) else sql

    n = _operand_sql("this")
    d = _operand_sql("expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and generate only the sampled operand."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and generate an empty string for it."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Generate a `TryCast` through the generator's regular ``cast_sql``."""
    cast_sql = self.cast_sql
    return cast_sql(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and generate an empty string for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report a column comment constraint as unsupported and generate an empty string."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Report MAP_FROM_ENTRIES as unsupported and generate an empty string for it."""
    message = "MAP_FROM_ENTRIES unsupported"
    self.unsupported(message)
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Generate a string-position lookup with ``STRPOS``.

    Without a start position this is ``STRPOS(this, substr)``. With one, the
    search runs on ``SUBSTR(this, position)`` and the match index is shifted
    back by ``position - 1``.

    The shift is only applied when a match was found. STRPOS returns 0 for
    "not found", and shifting that unconditionally would report
    ``position - 1`` instead of 0.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"

    offset_match = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    return f"CASE WHEN {offset_match} = 0 THEN 0 ELSE {offset_match} + {position} - 1 END"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Generate a struct field access in dot notation: ``<struct>.<field identifier>``."""
    struct_sql = self.sql(expression, "this")
    field = exp.to_identifier(expression.expression.name)
    field_sql = self.sql(field)
    return f"{struct_sql}.{field_sql}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Generate a map as ``map_func_name(key1, value1, key2, value2, ...)``.

    Interleaving needs both ``keys`` and ``values`` to be `exp.Array` nodes.
    Otherwise the conversion is reported as unsupported and the two args are
    passed to the function as they are.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    pairs = zip(keys.expressions, values.expressions)
    interleaved = [self.sql(item) for pair in pairs for item in pair]
    return self.func(map_func_name, *interleaved)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        value = seq_get(args, 0)
        # The dialect is looked up on every call, not when the helper is created.
        dialect_class = Dialect[dialect]

        time_format = seq_get(args, 1)
        if not time_format:
            time_format = dialect_class.TIME_FORMAT if default is True else default or None

        return exp_class(this=value, format=dialect_class.format_time(time_format))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Return a generator callback that yields an expression's non-default time format."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_format = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt != default_format:
            return fmt
        return None

    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
    if not (has_schema and is_partitionable):
        return self.create_sql(expression)

    # Everything below rewrites a copy, never the caller's tree.
    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)

    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        partition_names = {v.name.upper() for v in prop.this.expressions}
        partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
        remaining = [e for e in schema.expressions if e not in partitions]

        schema.set("expressions", remaining)
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Return a parser callback for date-delta functions.

    With exactly three arguments the call is read as ``(unit, expression, this)``.
    Otherwise the first argument is ``this`` and the unit defaults to the string
    literal ``DAY``. When ``unit_mapping`` is given, the unit's lower-cased name
    is translated through it (falling back to the name itself) and wrapped in a var.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Return a parser callback for date functions whose second argument is an INTERVAL.

    The callback returns None for fewer than two arguments and raises
    `ParseError` when the second argument is not an `exp.Interval`. A string
    interval value is converted to a number literal.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        target, interval = args[0], args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=target, expression=amount, unit=unit)

    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build a truncation node from ``(unit, value)`` arguments.

    A value that is a cast to ``date`` produces `DateTrunc`; anything else
    produces `TimestampTrunc`.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type("date")
    if not is_date_cast:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator callback emitting ``<data_type>_<kind>(this, INTERVAL ...)``.

    The interval is built from a copy of the node's ``expression`` and its
    upper-cased ``unit`` name, with ``DAY`` used when no unit is set.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")

        unit_arg = expression.args.get("unit")
        unit_name = unit_arg.name.upper() if unit_arg else "DAY"

        interval = exp.Interval(this=expression.expression.copy(), unit=exp.var(unit_name))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Generate ``DATE_TRUNC('<unit>', this)``, using ``day`` when the unit text is empty."""
    unit_text = expression.text("unit") or "day"
    unit = exp.Literal.string(unit_text)
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Build a `StrPosition` from ``(substr, this, position)`` ordered arguments."""
    substr = seq_get(args, 0)
    haystack = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=haystack, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Generate a `StrPosition` as ``LOCATE(substr, this, position)``."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Generate ``LEFT(x, n)`` as a substring of ``x`` starting at 1 with length ``n``."""
    # Build from a copy so the original node's children are not re-parented.
    left = expression.copy()
    substring = exp.Substring(
        this=left.this, start=exp.Literal.number(1), length=left.expression
    )
    return self.sql(substring)
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Generate ``RIGHT(x, n)`` as a substring of ``x`` starting at ``LENGTH(x) - (n - 1)``.

    The parameter is annotated as `exp.Right`; the previous `exp.Left`
    annotation was carried over from the LEFT helper.
    """
    # Build from a copy so the original node's children are not re-parented.
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Generate a `TimeStrToTime` as a cast of its operand to ``timestamp``."""
    as_timestamp = exp.cast(expression.this, "timestamp")
    return self.sql(as_timestamp)
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Generate a `DateStrToDate` as a cast of its operand to ``date``."""
    as_date = exp.cast(expression.this, "date")
    return self.sql(as_date)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode-style node as a call to the function `name`.

    Args:
        expression: The node to render; its optional "charset" and "replace"
            args are consulted.
        name: The function name to emit.
        replace: Whether to forward the node's "replace" arg as a second argument.

    Returns:
        The generated function call. A charset other than utf-8 is reported
        through `self.unsupported`.
    """
    node_args = expression.args

    charset = node_args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    if replace:
        return self.func(name, expression.this, node_args.get("replace"))
    return self.func(name, expression.this, None)
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render `exp.Min` as LEAST when `expression.expressions` is non-empty, else as MIN."""
    if expression.expressions:
        render = rename_func("LEAST")
    else:
        render = rename_func("MIN")
    return render(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render `exp.Max` as GREATEST when `expression.expressions` is non-empty, else as MAX."""
    if expression.expressions:
        render = rename_func("GREATEST")
    else:
        render = rename_func("MAX")
    return render(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render `exp.CountIf` as `sum(if(cond, 1, 0))`.

    If the operand is an `exp.Distinct`, its first expression is used as the
    condition and the DISTINCT is reported through `self.unsupported`.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    conditional_one = exp.func("if", condition.copy(), 1, 0)
    return self.func("sum", conditional_one)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render `exp.Trim` in the `TRIM([position] [chars] FROM target [COLLATE c])` form.

    When the node has neither characters to remove nor a collation, rendering is
    delegated to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    pieces = []
    if trim_type:
        pieces.append(f"{trim_type} ")
    if remove_chars:
        pieces.append(f"{remove_chars} ")
    if pieces:
        # FROM is only emitted when something precedes the target.
        pieces.append("FROM ")
    pieces.append(target)
    if collation:
        pieces.append(f" COLLATE {collation}")

    inner = "".join(pieces)
    return f"TRIM({inner})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render the node as `STRPTIME(this, <format>)`, with the format from `self.format_time`."""
    time_format = self.format_time(expression)
    return self.func("STRPTIME", expression.this, time_format)
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Return a generator function that renders `exp.TsOrDsToDate` as a cast to "date".

    Args:
        dialect: Name of the dialect whose `TIME_FORMAT` / `DATE_FORMAT` count as
            default formats.

    Returns:
        A function `(self, expression) -> str`. If the node has a format other than
        the dialect defaults, the value is rendered via `str_to_time_sql` before
        the cast; otherwise the operand is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        # The dialect is resolved on every call, not once in the factory.
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        has_custom_format = bool(time_format) and time_format not in (
            _dialect.TIME_FORMAT,
            _dialect.DATE_FORMAT,
        )
        if has_custom_format:
            source = str_to_time_sql(self, expression)
        else:
            source = self.sql(expression, "this")

        return self.sql(exp.cast(source, "date"))

    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render a concat node as a left-nested chain of `exp.DPipe` nodes.

    The node is copied, then its `expressions` are folded pairwise from the left.
    """

    def _pipe(left, right):
        return exp.DPipe(this=left, expression=right)

    operands = expression.copy().expressions
    return self.sql(reduce(_pipe, operands))
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render `exp.ConcatWs` as a chain of `exp.DPipe` nodes with the delimiter interleaved.

    The first entry of `expressions` is the delimiter. Each remaining entry after
    the first operand is appended as `<acc> || (<delim> || <operand>)`.
    """
    delim, *rest_args = expression.copy().expressions

    def _join(left, right):
        delimited = exp.DPipe(this=delim, expression=right)
        return exp.DPipe(this=left, expression=delimited)

    return self.sql(reduce(_join, rest_args))
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render `exp.RegexpExtract` as `REGEXP_EXTRACT(this, expression, group)`.

    Any of "position", "occurrence" or "parameters" that is set on the node is
    reported through `self.unsupported` and omitted from the call.
    """
    node_args = expression.args
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if node_args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = node_args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render `exp.RegexpReplace` as `REGEXP_REPLACE(this, expression, replacement)`.

    Any of "position", "occurrence" or "parameters" that is set on the node is
    reported through `self.unsupported` and omitted from the call. The
    "replacement" arg is read with a direct key lookup.
    """
    node_args = expression.args
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if node_args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = node_args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute one name per pivot aggregation.

    Args:
        aggregations: The aggregation expressions of the pivot.
        dialect: The dialect used to render aggregations that have no alias.

    Returns:
        For each aggregation, its alias if it is an `exp.Alias`; otherwise its SQL
        with identifiers unquoted and function names lower-cased.
    """

    def _unquote(node):
        if isinstance(node, exp.Identifier):
            return exp.Identifier(this=node.name, quoted=False)
        return node

    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as suffixes (e.g. col_avg(foo)).
        # Identifiers are unquoted here because they are going to be quoted in the
        # base parser's `_parse_pivot` method, due to `to_identifier`; otherwise
        # the identifier would end up quoted twice.
        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def simplify_literal(expression: E) -> E:
    """Run the optimizer's `simplify` on `expression.expression` unless it is a literal.

    Returns:
        The same `expression` object that was passed in.
    """
    operand = expression.expression
    if isinstance(operand, exp.Literal):
        return expression

    # Function-scope import, as in the original.
    from sqlglot.optimizer.simplify import simplify

    # The return value is discarded; presumably `simplify` rewrites the operand
    # in place -- TODO confirm.
    simplify(operand)
    return expression
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a builder that turns a two-argument list into an `expr_type` node.

    The builder maps the first argument to `this` and the second to `expression`,
    looking both up with `seq_get`.
    """

    def _build(args):
        return expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

    return _build
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build an `exp.TimestampTrunc` from arguments given in (unit, this) order."""
    unit = seq_get(args, 0)
    value = seq_get(args, 1)
    return exp.TimestampTrunc(this=value, unit=unit)
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render `exp.AnyValue` as a `MAX(this)` call."""
    operand = expression.this
    return self.func("MAX", operand)
def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Render `exp.Xor` as `(a AND (NOT b)) OR ((NOT a) AND b)`.

    Args:
        expression: The XOR node; its `left` and `right` operands are rendered.

    Returns:
        The expanded boolean expression. Operands that are themselves AND / OR /
        XOR nodes are wrapped in parentheses so they keep their grouping.
    """

    def _operand_sql(node):
        sql = self.sql(node)
        # NOT and AND bind tighter than a bare connector operand: without the
        # parentheses, "NOT x AND y" would negate only x.
        if isinstance(node, (exp.And, exp.Or, exp.Xor)):
            return f"({sql})"
        return sql

    a = _operand_sql(expression.left)
    b = _operand_sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Render `exp.JSONKeyValue` as its key and value separated by a comma."""
    key_sql = self.sql(expression, "this")
    value_sql = self.sql(expression, "expression")
    return f"{key_sql}, {value_sql}"
def is_parse_json(expression: exp.Expression) -> bool:
    """Return whether the node is an `exp.ParseJSON` or an `exp.Cast` to the "json" type."""
    if isinstance(expression, exp.ParseJSON):
        return True
    return isinstance(expression, exp.Cast) and expression.is_type("json")
def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Build a parenthesized `<first arg> IS NULL` node from a function argument list."""
    operand = seq_get(args, 0)
    is_null = exp.Is(this=operand, expression=exp.null())
    return exp.Paren(this=is_null)
def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
    """Render an INSERT, first moving a "with" arg from its source onto the INSERT itself.

    If `expression.expression` has no "with" arg, the node is rendered unchanged.
    Otherwise the move is done on a copy, so the caller's node is not modified.
    """
    if not expression.expression.args.get("with"):
        return self.insert_sql(expression)

    insert = expression.copy()
    # pop() detaches the "with" node from the source before it is attached to the INSERT.
    insert.set("with", insert.expression.args["with"].pop())
    return self.insert_sql(insert)