
sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum, auto
  5from functools import reduce
  6
  7from sqlglot import exp
  8from sqlglot._typing import E
  9from sqlglot.errors import ParseError
 10from sqlglot.generator import Generator
 11from sqlglot.helper import AutoName, flatten, seq_get
 12from sqlglot.parser import Parser
 13from sqlglot.time import TIMEZONES, format_time
 14from sqlglot.tokens import Token, Tokenizer, TokenType
 15from sqlglot.trie import new_trie
 16
 17B = t.TypeVar("B", bound=exp.Binary)
 18
 19DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
 20DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
 21
 22
 23class Dialects(str, Enum):
 24    """Dialects supported by SQLGLot."""
 25
 26    DIALECT = ""
 27
 28    BIGQUERY = "bigquery"
 29    CLICKHOUSE = "clickhouse"
 30    DATABRICKS = "databricks"
 31    DORIS = "doris"
 32    DRILL = "drill"
 33    DUCKDB = "duckdb"
 34    HIVE = "hive"
 35    MYSQL = "mysql"
 36    ORACLE = "oracle"
 37    POSTGRES = "postgres"
 38    PRESTO = "presto"
 39    REDSHIFT = "redshift"
 40    SNOWFLAKE = "snowflake"
 41    SPARK = "spark"
 42    SPARK2 = "spark2"
 43    SQLITE = "sqlite"
 44    STARROCKS = "starrocks"
 45    TABLEAU = "tableau"
 46    TERADATA = "teradata"
 47    TRINO = "trino"
 48    TSQL = "tsql"
 49
 50
 51class NormalizationStrategy(str, AutoName):
 52    """Specifies the strategy according to which identifiers should be normalized."""
 53
 54    LOWERCASE = auto()
 55    """Unquoted identifiers are lowercased."""
 56
 57    UPPERCASE = auto()
 58    """Unquoted identifiers are uppercased."""
 59
 60    CASE_SENSITIVE = auto()
 61    """Always case-sensitive, regardless of quotes."""
 62
 63    CASE_INSENSITIVE = auto()
 64    """Always case-insensitive, regardless of quotes."""
 65
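# Editorial sketch: once `Dialect` (defined below) is wired up, these
# strategies behave as in this doctest-style illustration, e.g. Postgres
# lowercases and Snowflake uppercases unquoted identifiers:
#
#     >>> from sqlglot import exp
#     >>> from sqlglot.dialects.dialect import Dialect
#     >>> Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name
#     'foo'
#     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name
#     'FOO'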
 66
 67class _Dialect(type):
 68    classes: t.Dict[str, t.Type[Dialect]] = {}
 69
 70    def __eq__(cls, other: t.Any) -> bool:
 71        if cls is other:
 72            return True
 73        if isinstance(other, str):
 74            return cls is cls.get(other)
 75        if isinstance(other, Dialect):
 76            return cls is type(other)
 77
 78        return False
 79
 80    def __hash__(cls) -> int:
 81        return hash(cls.__name__.lower())
 82
 83    @classmethod
 84    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 85        return cls.classes[key]
 86
 87    @classmethod
 88    def get(
 89        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 90    ) -> t.Optional[t.Type[Dialect]]:
 91        return cls.classes.get(key, default)
 92
 93    def __new__(cls, clsname, bases, attrs):
 94        klass = super().__new__(cls, clsname, bases, attrs)
 95        enum = Dialects.__members__.get(clsname.upper())
 96        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 97
 98        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
 99        klass.FORMAT_TRIE = (
100            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
101        )
102        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
103        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
104
105        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
106
107        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
108        klass.parser_class = getattr(klass, "Parser", Parser)
109        klass.generator_class = getattr(klass, "Generator", Generator)
110
111        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
112        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
113            klass.tokenizer_class._IDENTIFIERS.items()
114        )[0]
115
116        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
117            return next(
118                (
119                    (s, e)
120                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
121                    if t == token_type
122                ),
123                (None, None),
124            )
125
126        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
127        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
128        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
129        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
130
131        if enum not in ("", "bigquery"):
132            klass.generator_class.SELECT_KINDS = ()
133
134        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
135            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
136                TokenType.ANTI,
137                TokenType.SEMI,
138            }
139
140        return klass
141
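# Sketch: the metaclass registry, keyed by the Dialects enum value (or the
# lowercased class name), is what makes string lookups and string comparisons
# against dialect classes work:
#
#     >>> from sqlglot.dialects.dialect import Dialect
#     >>> Dialect["duckdb"]                    # registry lookup via __getitem__
#     <class 'sqlglot.dialects.duckdb.DuckDB'>
#     >>> Dialect["duckdb"] == "duckdb"        # _Dialect.__eq__ accepts strings
#     True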
142
143class Dialect(metaclass=_Dialect):
144    INDEX_OFFSET = 0
145    """Determines the base index offset for arrays."""
146
147    WEEK_OFFSET = 0
148    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
149
150    UNNEST_COLUMN_ONLY = False
151    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
152
153    ALIAS_POST_TABLESAMPLE = False
154    """Determines whether or not the table alias comes after tablesample."""
155
156    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
157    """Specifies the strategy according to which identifiers should be normalized."""
158
159    IDENTIFIERS_CAN_START_WITH_DIGIT = False
160    """Determines whether or not an unquoted identifier can start with a digit."""
161
162    DPIPE_IS_STRING_CONCAT = True
163    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
164
165    STRICT_STRING_CONCAT = False
166    """Determines whether or not `CONCAT`'s arguments must be strings."""
167
168    SUPPORTS_USER_DEFINED_TYPES = True
169    """Determines whether or not user-defined data types are supported."""
170
171    SUPPORTS_SEMI_ANTI_JOIN = True
172    """Determines whether or not `SEMI` or `ANTI` joins are supported."""
173
174    NORMALIZE_FUNCTIONS: bool | str = "upper"
175    """Determines how function names are going to be normalized."""
176
177    LOG_BASE_FIRST = True
178    """Determines whether the base comes first in the `LOG` function."""
179
180    NULL_ORDERING = "nulls_are_small"
181    """
182    Indicates the default `NULL` ordering method to use if not explicitly set.
183    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
184    """
185
186    TYPED_DIVISION = False
187    """
188    Whether the behavior of `a / b` depends on the types of `a` and `b`.
189    False means `a / b` is always float division.
190    True means `a / b` is integer division if both `a` and `b` are integers.
191    """
192
193    SAFE_DIVISION = False
194    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
195
196    CONCAT_COALESCE = False
197    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
198
199    DATE_FORMAT = "'%Y-%m-%d'"
200    DATEINT_FORMAT = "'%Y%m%d'"
201    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
202
203    TIME_MAPPING: t.Dict[str, str] = {}
204    """Associates this dialect's time formats with their equivalent Python `strftime` format."""
205
206    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
207    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
208    FORMAT_MAPPING: t.Dict[str, str] = {}
209    """
210    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
211    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
212    """
213
214    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
215    """Mapping of an unescaped escape sequence to the corresponding character."""
216
217    PSEUDOCOLUMNS: t.Set[str] = set()
218    """
219    Columns that are auto-generated by the engine corresponding to this dialect.
220    For example, such columns may be excluded from `SELECT *` queries.
221    """
222
223    PREFER_CTE_ALIAS_COLUMN = False
224    """
225    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
226    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
227    any projection aliases in the subquery.
228
229    For example,
230        WITH y(c) AS (
231            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
232        ) SELECT c FROM y;
233
234        will be rewritten as
235
236        WITH y(c) AS (
237            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
238        ) SELECT c FROM y;
239    """
240
241    # --- Autofilled ---
242
243    tokenizer_class = Tokenizer
244    parser_class = Parser
245    generator_class = Generator
246
247    # A trie of the time_mapping keys
248    TIME_TRIE: t.Dict = {}
249    FORMAT_TRIE: t.Dict = {}
250
251    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
252    INVERSE_TIME_TRIE: t.Dict = {}
253
254    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
255
256    # Delimiters for quotes, identifiers and the corresponding escape characters
257    QUOTE_START = "'"
258    QUOTE_END = "'"
259    IDENTIFIER_START = '"'
260    IDENTIFIER_END = '"'
261
262    # Delimiters for bit, hex, byte and unicode literals
263    BIT_START: t.Optional[str] = None
264    BIT_END: t.Optional[str] = None
265    HEX_START: t.Optional[str] = None
266    HEX_END: t.Optional[str] = None
267    BYTE_START: t.Optional[str] = None
268    BYTE_END: t.Optional[str] = None
269    UNICODE_START: t.Optional[str] = None
270    UNICODE_END: t.Optional[str] = None
271
272    @classmethod
273    def get_or_raise(cls, dialect: DialectType) -> Dialect:
274        """
275        Look up a dialect in the global dialect registry and return it if it exists.
276
277        Args:
278            dialect: The target dialect. If this is a string, it can be optionally followed by
279                additional key-value pairs that are separated by commas and are used to specify
280                dialect settings, such as whether the dialect's identifiers are case-sensitive.
281
282        Example:
283            >>> dialect = dialect_class = get_or_raise("duckdb")
284            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
285
286        Returns:
287            The corresponding Dialect instance.
288        """
289
290        if not dialect:
291            return cls()
292        if isinstance(dialect, _Dialect):
293            return dialect()
294        if isinstance(dialect, Dialect):
295            return dialect
296        if isinstance(dialect, str):
297            try:
298                dialect_name, *kv_pairs = dialect.split(",")
299                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
300            except ValueError:
301                raise ValueError(
302                    f"Invalid dialect format: '{dialect}'. "
303                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
304                )
305
306            result = cls.get(dialect_name.strip())
307            if not result:
308                raise ValueError(f"Unknown dialect '{dialect_name}'.")
309
310            return result(**kwargs)
311
312        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
313
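    # Sketch: key-value settings in the string form above are forwarded to the
    # dialect constructor, so they end up on the resulting instance:
    #
    #     >>> d = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    #     >>> d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
    #     True
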
314    @classmethod
315    def format_time(
316        cls, expression: t.Optional[str | exp.Expression]
317    ) -> t.Optional[exp.Expression]:
318        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
319        if isinstance(expression, str):
320            return exp.Literal.string(
321                # the time formats are quoted
322                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
323            )
324
325        if expression and expression.is_string:
326            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
327
328        return expression
329
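    # Sketch, assuming the target dialect's TIME_MAPPING covers the tokens used
    # (Hive's does at the time of writing); note that the surrounding quotes of
    # a raw string argument are stripped before translation:
    #
    #     >>> Dialect["hive"].format_time("'yyyy-MM-dd'").name
    #     '%Y-%m-%d'
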
330    def __init__(self, **kwargs) -> None:
331        normalization_strategy = kwargs.get("normalization_strategy")
332
333        if normalization_strategy is None:
334            self.normalization_strategy = self.NORMALIZATION_STRATEGY
335        else:
336            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
337
338    def __eq__(self, other: t.Any) -> bool:
339        # Does not currently take dialect state into account
340        return type(self) == other
341
342    def __hash__(self) -> int:
343        # Does not currently take dialect state into account
344        return hash(type(self))
345
346    def normalize_identifier(self, expression: E) -> E:
347        """
348        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
349
350        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
351        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
352        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
353        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
354
355        There are also dialects like Spark, which are case-insensitive even when quotes are
356        present, and dialects like MySQL, whose resolution rules match those employed by the
357        underlying operating system, for example they may always be case-sensitive in Linux.
358
359        Finally, the normalization behavior of some engines can even be controlled through flags,
360        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
361
362        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
363        that it can analyze queries in the optimizer and successfully capture their semantics.
364        """
365        if (
366            isinstance(expression, exp.Identifier)
367            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
368            and (
369                not expression.quoted
370                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
371            )
372        ):
373            expression.set(
374                "this",
375                expression.this.upper()
376                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
377                else expression.this.lower(),
378            )
379
380        return expression
381
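    # Sketch of the quoting interplay: quoted identifiers are preserved unless
    # the strategy is CASE_INSENSITIVE (e.g. Spark), in which case they are
    # normalized as well:
    #
    #     >>> from sqlglot import exp
    #     >>> Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
    #     'FoO'
    #     >>> Dialect.get_or_raise("spark").normalize_identifier(exp.to_identifier("FoO", quoted=True)).name
    #     'foo'
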
382    def case_sensitive(self, text: str) -> bool:
383        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
384        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
385            return False
386
387        unsafe = (
388            str.islower
389            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
390            else str.isupper
391        )
392        return any(unsafe(char) for char in text)
393
394    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
395        """Checks if text can be identified given an identify option.
396
397        Args:
398            text: The text to check.
399            identify:
400                `"always"` or `True`: Always returns `True`.
401                `"safe"`: Only returns `True` if the identifier is case-insensitive.
402
403        Returns:
404            Whether or not the given text can be identified.
405        """
406        if identify is True or identify == "always":
407            return True
408
409        if identify == "safe":
410            return not self.case_sensitive(text)
411
412        return False
413
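    # Sketch for a lowercase-normalizing dialect such as DuckDB: "safe" only
    # accepts identifiers that normalization would leave unchanged:
    #
    #     >>> d = Dialect.get_or_raise("duckdb")
    #     >>> d.can_identify("foo", "safe"), d.can_identify("FOO", "safe"), d.can_identify("FOO", "always")
    #     (True, False, True)
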
414    def quote_identifier(self, expression: E, identify: bool = True) -> E:
415        """
416        Adds quotes to a given identifier.
417
418        Args:
419            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
420            identify: If set to `False`, the quotes will only be added if the identifier is deemed
421                "unsafe", with respect to its characters and this dialect's normalization strategy.
422        """
423        if isinstance(expression, exp.Identifier):
424            name = expression.this
425            expression.set(
426                "quoted",
427                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
428            )
429
430        return expression
431
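    # Sketch: with identify=False, quotes are only added when actually needed:
    #
    #     >>> from sqlglot import exp
    #     >>> d = Dialect.get_or_raise("postgres")
    #     >>> d.quote_identifier(exp.to_identifier("foo"), identify=False).sql(dialect="postgres")
    #     'foo'
    #     >>> d.quote_identifier(exp.to_identifier("Foo"), identify=False).sql(dialect="postgres")
    #     '"Foo"'
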
432    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
433        return self.parser(**opts).parse(self.tokenize(sql), sql)
434
435    def parse_into(
436        self, expression_type: exp.IntoType, sql: str, **opts
437    ) -> t.List[t.Optional[exp.Expression]]:
438        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
439
440    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
441        return self.generator(**opts).generate(expression, copy=copy)
442
443    def transpile(self, sql: str, **opts) -> t.List[str]:
444        return [
445            self.generate(expression, copy=False, **opts) if expression else ""
446            for expression in self.parse(sql)
447        ]
448
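    # Round-trip sketch: parse with one dialect instance, generate with another:
    #
    #     >>> read, write = Dialect.get_or_raise("mysql"), Dialect.get_or_raise("duckdb")
    #     >>> write.generate(read.parse("SELECT 1 AS `x`")[0])
    #     'SELECT 1 AS "x"'
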
449    def tokenize(self, sql: str) -> t.List[Token]:
450        return self.tokenizer.tokenize(sql)
451
452    @property
453    def tokenizer(self) -> Tokenizer:
454        if not hasattr(self, "_tokenizer"):
455            self._tokenizer = self.tokenizer_class(dialect=self)
456        return self._tokenizer
457
458    def parser(self, **opts) -> Parser:
459        return self.parser_class(dialect=self, **opts)
460
461    def generator(self, **opts) -> Generator:
462        return self.generator_class(dialect=self, **opts)
463
464
465DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
466
467
468def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
469    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
470
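# Sketch: rename_func builds a simple generator transform, handy in a dialect's
# Generator TRANSFORMS map when only the function name differs:
#
#     >>> from sqlglot import exp
#     >>> from sqlglot.dialects.dialect import Dialect
#     >>> gen = Dialect.get_or_raise("duckdb").generator()
#     >>> rename_func("POW")(gen, exp.func("POWER", 2, 10))
#     'POW(2, 10)'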
471
472def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
473    if expression.args.get("accuracy"):
474        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
475    return self.func("APPROX_COUNT_DISTINCT", expression.this)
476
477
478def if_sql(
479    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
480) -> t.Callable[[Generator, exp.If], str]:
481    def _if_sql(self: Generator, expression: exp.If) -> str:
482        return self.func(
483            name,
484            expression.this,
485            expression.args.get("true"),
486            expression.args.get("false") or false_value,
487        )
488
489    return _if_sql
490
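# Sketch: a dialect can register e.g. {exp.If: if_sql(name="IIF", false_value="NULL")}
# in its Generator TRANSFORMS to force a default else-branch:
#
#     >>> gen = Dialect.get_or_raise("duckdb").generator()
#     >>> if_sql(name="IIF", false_value="NULL")(gen, exp.If(this=exp.column("x"), true=exp.Literal.number(1)))
#     'IIF(x, 1, NULL)'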
491
492def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
493    return self.binary(expression, "->")
494
495
496def arrow_json_extract_scalar_sql(
497    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
498) -> str:
499    return self.binary(expression, "->>")
500
501
502def inline_array_sql(self: Generator, expression: exp.Array) -> str:
503    return f"[{self.expressions(expression, flat=True)}]"
504
505
506def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
507    return self.like_sql(
508        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
509    )
510
511
512def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
513    zone = self.sql(expression, "this")
514    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
515
516
517def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
518    if expression.args.get("recursive"):
519        self.unsupported("Recursive CTEs are unsupported")
520        expression.args["recursive"] = False
521    return self.with_sql(expression)
522
523
524def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
525    n = self.sql(expression, "this")
526    d = self.sql(expression, "expression")
527    return f"IF({d} <> 0, {n} / {d}, NULL)"
528
529
530def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
531    self.unsupported("TABLESAMPLE unsupported")
532    return self.sql(expression.this)
533
534
535def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
536    self.unsupported("PIVOT unsupported")
537    return ""
538
539
540def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
541    return self.cast_sql(expression)
542
543
544def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
545    self.unsupported("Properties unsupported")
546    return ""
547
548
549def no_comment_column_constraint_sql(
550    self: Generator, expression: exp.CommentColumnConstraint
551) -> str:
552    self.unsupported("CommentColumnConstraint unsupported")
553    return ""
554
555
556def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
557    self.unsupported("MAP_FROM_ENTRIES unsupported")
558    return ""
559
560
561def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
562    this = self.sql(expression, "this")
563    substr = self.sql(expression, "substr")
564    position = self.sql(expression, "position")
565    if position:
566        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
567    return f"STRPOS({this}, {substr})"
568
569
570def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
571    return (
572        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
573    )
574
575
576def var_map_sql(
577    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
578) -> str:
579    keys = expression.args["keys"]
580    values = expression.args["values"]
581
582    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
583        self.unsupported("Cannot convert array columns into map.")
584        return self.func(map_func_name, keys, values)
585
586    args = []
587    for key, value in zip(keys.expressions, values.expressions):
588        args.append(self.sql(key))
589        args.append(self.sql(value))
590
591    return self.func(map_func_name, *args)
592
593
594def format_time_lambda(
595    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
596) -> t.Callable[[t.List], E]:
597    """Helper used for time expressions.
598
599    Args:
600        exp_class: the expression class to instantiate.
601        dialect: target sql dialect.
602        default: the default format, True being time.
603
604    Returns:
605        A callable that can be used to return the appropriately formatted time expression.
606    """
607
608    def _format_time(args: t.List):
609        return exp_class(
610            this=seq_get(args, 0),
611            format=Dialect[dialect].format_time(
612                seq_get(args, 1)
613                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
614            ),
615        )
616
617    return _format_time
618
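# Sketch: this factory is meant for function-parser registrations; in the
# spirit of Hive's FROM_UNIXTIME, an entry like
#     "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True)
# maps an optional second argument through Dialect["hive"].format_time and,
# because default is True, falls back to the dialect's TIME_FORMAT.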
619
620def time_format(
621    dialect: DialectType = None,
622) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
623    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
624        """
625        Returns the time format for a given expression, unless it's equivalent
626        to the default time format of the dialect of interest.
627        """
628        time_format = self.format_time(expression)
629        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
630
631    return _time_format
632
633
634def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
635    """
636    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
637    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
638    columns are removed from the create statement.
639    """
640    has_schema = isinstance(expression.this, exp.Schema)
641    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
642
643    if has_schema and is_partitionable:
644        prop = expression.find(exp.PartitionedByProperty)
645        if prop and prop.this and not isinstance(prop.this, exp.Schema):
646            schema = expression.this
647            columns = {v.name.upper() for v in prop.this.expressions}
648            partitions = [col for col in schema.expressions if col.name.upper() in columns]
649            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
650            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
651            expression.set("this", schema)
652
653    return self.create_sql(expression)
654
655
656def parse_date_delta(
657    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
658) -> t.Callable[[t.List], E]:
659    def inner_func(args: t.List) -> E:
660        unit_based = len(args) == 3
661        this = args[2] if unit_based else seq_get(args, 0)
662        unit = args[0] if unit_based else exp.Literal.string("DAY")
663        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
664        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
665
666    return inner_func
667
668
669def parse_date_delta_with_interval(
670    expression_class: t.Type[E],
671) -> t.Callable[[t.List], t.Optional[E]]:
672    def func(args: t.List) -> t.Optional[E]:
673        if len(args) < 2:
674            return None
675
676        interval = args[1]
677
678        if not isinstance(interval, exp.Interval):
679            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
680
681        expression = interval.this
682        if expression and expression.is_string:
683            expression = exp.Literal.number(expression.this)
684
685        return expression_class(
686            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
687        )
688
689    return func
690
691
692def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
693    unit = seq_get(args, 0)
694    this = seq_get(args, 1)
695
696    if isinstance(this, exp.Cast) and this.is_type("date"):
697        return exp.DateTrunc(unit=unit, this=this)
698    return exp.TimestampTrunc(this=this, unit=unit)
699
700
701def date_add_interval_sql(
702    data_type: str, kind: str
703) -> t.Callable[[Generator, exp.Expression], str]:
704    def func(self: Generator, expression: exp.Expression) -> str:
705        this = self.sql(expression, "this")
706        unit = expression.args.get("unit")
707        unit = exp.var(unit.name.upper() if unit else "DAY")
708        interval = exp.Interval(this=expression.expression, unit=unit)
709        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
710
711    return func
712
713
714def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
715    return self.func(
716        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
717    )
718
719
720def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
721    if not expression.expression:
722        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
723    if expression.text("expression").lower() in TIMEZONES:
724        return self.sql(
725            exp.AtTimeZone(
726                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
727                zone=expression.expression,
728            )
729        )
730    return self.function_fallback_sql(expression)
731
732
733def locate_to_strposition(args: t.List) -> exp.Expression:
734    return exp.StrPosition(
735        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
736    )
737
738
739def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
740    return self.func(
741        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
742    )
743
744
745def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
746    return self.sql(
747        exp.Substring(
748            this=expression.this, start=exp.Literal.number(1), length=expression.expression
749        )
750    )
751
752
753def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
754    return self.sql(
755        exp.Substring(
756            this=expression.this,
757            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
758        )
759    )
760
761
762def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
763    return self.sql(exp.cast(expression.this, "timestamp"))
764
765
766def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
767    return self.sql(exp.cast(expression.this, "date"))
768
769
770# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
771def encode_decode_sql(
772    self: Generator, expression: exp.Expression, name: str, replace: bool = True
773) -> str:
774    charset = expression.args.get("charset")
775    if charset and charset.name.lower() != "utf-8":
776        self.unsupported(f"Expected utf-8 character set, got {charset}.")
777
778    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
779
780
781def min_or_least(self: Generator, expression: exp.Min) -> str:
782    name = "LEAST" if expression.expressions else "MIN"
783    return rename_func(name)(self, expression)
784
785
786def max_or_greatest(self: Generator, expression: exp.Max) -> str:
787    name = "GREATEST" if expression.expressions else "MAX"
788    return rename_func(name)(self, expression)
789
790
791def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
792    cond = expression.this
793
794    if isinstance(expression.this, exp.Distinct):
795        cond = expression.this.expressions[0]
796        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
797
798    return self.func("sum", exp.func("if", cond, 1, 0))
799
800
801def trim_sql(self: Generator, expression: exp.Trim) -> str:
802    target = self.sql(expression, "this")
803    trim_type = self.sql(expression, "position")
804    remove_chars = self.sql(expression, "expression")
805    collation = self.sql(expression, "collation")
806
807    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
808    if not remove_chars and not collation:
809        return self.trim_sql(expression)
810
811    trim_type = f"{trim_type} " if trim_type else ""
812    remove_chars = f"{remove_chars} " if remove_chars else ""
813    from_part = "FROM " if trim_type or remove_chars else ""
814    collation = f" COLLATE {collation}" if collation else ""
815    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
816
817
818def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
819    return self.func("STRPTIME", expression.this, self.format_time(expression))
820
821
822def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
823    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
824        _dialect = Dialect.get_or_raise(dialect)
825        time_format = self.format_time(expression)
826        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
827            return self.sql(
828                exp.cast(
829                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
830                    "date",
831                )
832            )
833        return self.sql(exp.cast(expression.this, "date"))
834
835    return _ts_or_ds_to_date_sql
836
837
838def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
839    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
840
841
842def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
843    delim, *rest_args = expression.expressions
844    return self.sql(
845        reduce(
846            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
847            rest_args,
848        )
849    )
850
851
852def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
853    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
854    if bad_args:
855        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
856
857    return self.func(
858        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
859    )
860
861
862def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
863    bad_args = list(
864        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
865    )
866    if bad_args:
867        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
868
869    return self.func(
870        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
871    )
872
873
874def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
875    names = []
876    for agg in aggregations:
877        if isinstance(agg, exp.Alias):
878            names.append(agg.alias)
879        else:
880            """
881            This case corresponds to aggregations without aliases being used as suffixes
882            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
883            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
884            Otherwise, we'd end up with `col_avg(`foo`)` (notice the doubled backticks).
885            """
886            agg_all_unquoted = agg.transform(
887                lambda node: exp.Identifier(this=node.name, quoted=False)
888                if isinstance(node, exp.Identifier)
889                else node
890            )
891            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
892
893    return names
894
895
896def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
897    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
898
899
900# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
901def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
902    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
903
904
905def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
906    return self.func("MAX", expression.this)
907
908
909def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
910    a = self.sql(expression.left)
911    b = self.sql(expression.right)
912    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
913
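# Sketch, for engines without a boolean XOR operator:
#
#     >>> gen = Dialect.get_or_raise("duckdb").generator()
#     >>> bool_xor_sql(gen, exp.Xor(this=exp.column("a"), expression=exp.column("b")))
#     '(a AND (NOT b)) OR ((NOT a) AND b)'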
914
915# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
916def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
917    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
918
919
920def is_parse_json(expression: exp.Expression) -> bool:
921    return isinstance(expression, exp.ParseJSON) or (
922        isinstance(expression, exp.Cast) and expression.is_type("json")
923    )
924
925
926def isnull_to_is_null(args: t.List) -> exp.Expression:
927    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
928
929
930def generatedasidentitycolumnconstraint_sql(
931    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
932) -> str:
933    start = self.sql(expression, "start") or "1"
934    increment = self.sql(expression, "increment") or "1"
935    return f"IDENTITY({start}, {increment})"
936
937
938def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
939    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
940        if expression.args.get("count"):
941            self.unsupported(f"Only two arguments are supported in function {name}.")
942
943        return self.func(name, expression.this, expression.expression)
944
945    return _arg_max_or_min_sql
946
947
948def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
949    this = expression.this.copy()
950
951    return_type = expression.return_type
952    if return_type.is_type(exp.DataType.Type.DATE):
953        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
954        # can truncate timestamp strings, because some dialects can't cast them to DATE
955        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
956
957    expression.this.replace(exp.cast(this, return_type))
958    return expression
959
960
961def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
962    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
963        if cast and isinstance(expression, exp.TsOrDsAdd):
964            expression = ts_or_ds_add_cast(expression)
965
966        return self.func(
967            name,
968            exp.var(expression.text("unit").upper() or "DAY"),
969            expression.expression,
970            expression.this,
971        )
972
973    return _delta_sql
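
# Closing sketch: date_delta_sql builds unit-first renderers for date
# arithmetic, e.g. a Snowflake/TSQL-style DATEADD (unit defaults to DAY):
#
#     >>> gen = Dialect.get_or_raise("duckdb").generator()
#     >>> date_delta_sql("DATEADD")(gen, exp.DateAdd(this=exp.column("d"), expression=exp.Literal.number(1)))
#     'DATEADD(DAY, 1, d)'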
class Dialects(builtins.str, enum.Enum):
24class Dialects(str, Enum):
25    """Dialects supported by SQLGLot."""
26
27    DIALECT = ""
28
29    BIGQUERY = "bigquery"
30    CLICKHOUSE = "clickhouse"
31    DATABRICKS = "databricks"
32    DORIS = "doris"
33    DRILL = "drill"
34    DUCKDB = "duckdb"
35    HIVE = "hive"
36    MYSQL = "mysql"
37    ORACLE = "oracle"
38    POSTGRES = "postgres"
39    PRESTO = "presto"
40    REDSHIFT = "redshift"
41    SNOWFLAKE = "snowflake"
42    SPARK = "spark"
43    SPARK2 = "spark2"
44    SQLITE = "sqlite"
45    STARROCKS = "starrocks"
46    TABLEAU = "tableau"
47    TERADATA = "teradata"
48    TRINO = "trino"
49    TSQL = "tsql"

Dialects supported by SQLGLot.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DORIS = <Dialects.DORIS: 'doris'>
DRILL = <Dialects.DRILL: 'drill'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TERADATA = <Dialects.TERADATA: 'teradata'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class NormalizationStrategy(builtins.str, sqlglot.helper.AutoName):
52class NormalizationStrategy(str, AutoName):
53    """Specifies the strategy according to which identifiers should be normalized."""
54
55    LOWERCASE = auto()
56    """Unquoted identifiers are lowercased."""
57
58    UPPERCASE = auto()
59    """Unquoted identifiers are uppercased."""
60
61    CASE_SENSITIVE = auto()
62    """Always case-sensitive, regardless of quotes."""
63
64    CASE_INSENSITIVE = auto()
65    """Always case-insensitive, regardless of quotes."""

Specifies the strategy according to which identifiers should be normalized.

LOWERCASE = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Unquoted identifiers are lowercased.

UPPERCASE = <NormalizationStrategy.UPPERCASE: 'UPPERCASE'>

Unquoted identifiers are uppercased.

CASE_SENSITIVE = <NormalizationStrategy.CASE_SENSITIVE: 'CASE_SENSITIVE'>

Always case-sensitive, regardless of quotes.

CASE_INSENSITIVE = <NormalizationStrategy.CASE_INSENSITIVE: 'CASE_INSENSITIVE'>

Always case-insensitive, regardless of quotes.

Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
144class Dialect(metaclass=_Dialect):
145    INDEX_OFFSET = 0
146    """Determines the base index offset for arrays."""
147
148    WEEK_OFFSET = 0
149    """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
150
151    UNNEST_COLUMN_ONLY = False
152    """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
153
154    ALIAS_POST_TABLESAMPLE = False
155    """Determines whether or not the table alias comes after tablesample."""
156
157    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
158    """Specifies the strategy according to which identifiers should be normalized."""
159
160    IDENTIFIERS_CAN_START_WITH_DIGIT = False
161    """Determines whether or not an unquoted identifier can start with a digit."""
162
163    DPIPE_IS_STRING_CONCAT = True
164    """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
165
166    STRICT_STRING_CONCAT = False
167    """Determines whether or not `CONCAT`'s arguments must be strings."""
168
169    SUPPORTS_USER_DEFINED_TYPES = True
170    """Determines whether or not user-defined data types are supported."""
171
172    SUPPORTS_SEMI_ANTI_JOIN = True
173    """Determines whether or not `SEMI` or `ANTI` joins are supported."""
174
175    NORMALIZE_FUNCTIONS: bool | str = "upper"
176    """Determines how function names are going to be normalized."""
177
178    LOG_BASE_FIRST = True
179    """Determines whether the base comes first in the `LOG` function."""
180
181    NULL_ORDERING = "nulls_are_small"
182    """
183    Indicates the default `NULL` ordering method to use if not explicitly set.
184    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
185    """
186
187    TYPED_DIVISION = False
188    """
189    Whether the behavior of `a / b` depends on the types of `a` and `b`.
190    False means `a / b` is always float division.
191    True means `a / b` is integer division if both `a` and `b` are integers.
192    """
193
194    SAFE_DIVISION = False
195    """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
196
197    CONCAT_COALESCE = False
198    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
199
200    DATE_FORMAT = "'%Y-%m-%d'"
201    DATEINT_FORMAT = "'%Y%m%d'"
202    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
203
204    TIME_MAPPING: t.Dict[str, str] = {}
205    """Associates this dialect's time formats with their equivalent Python `strftime` format."""
206
207    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
208    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
209    FORMAT_MAPPING: t.Dict[str, str] = {}
210    """
211    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
212    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
213    """
214
215    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
216    """Mapping of an unescaped escape sequence to the corresponding character."""
217
218    PSEUDOCOLUMNS: t.Set[str] = set()
219    """
220    Columns that are auto-generated by the engine corresponding to this dialect.
221    For example, such columns may be excluded from `SELECT *` queries.
222    """
223
224    PREFER_CTE_ALIAS_COLUMN = False
225    """
226    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
227    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
228    any projection aliases in the subquery.
229
230    For example,
231        WITH y(c) AS (
232            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
233        ) SELECT c FROM y;
234
235        will be rewritten as
236
237        WITH y(c) AS (
238            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
239        ) SELECT c FROM y;
240    """
241
242    # --- Autofilled ---
243
244    tokenizer_class = Tokenizer
245    parser_class = Parser
246    generator_class = Generator
247
248    # A trie of the time_mapping keys
249    TIME_TRIE: t.Dict = {}
250    FORMAT_TRIE: t.Dict = {}
251
252    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
253    INVERSE_TIME_TRIE: t.Dict = {}
254
255    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
256
257    # Delimiters for quotes, identifiers and the corresponding escape characters
258    QUOTE_START = "'"
259    QUOTE_END = "'"
260    IDENTIFIER_START = '"'
261    IDENTIFIER_END = '"'
262
263    # Delimiters for bit, hex, byte and unicode literals
264    BIT_START: t.Optional[str] = None
265    BIT_END: t.Optional[str] = None
266    HEX_START: t.Optional[str] = None
267    HEX_END: t.Optional[str] = None
268    BYTE_START: t.Optional[str] = None
269    BYTE_END: t.Optional[str] = None
270    UNICODE_START: t.Optional[str] = None
271    UNICODE_END: t.Optional[str] = None
272
273    @classmethod
274    def get_or_raise(cls, dialect: DialectType) -> Dialect:
275        """
276        Look up a dialect in the global dialect registry and return it if it exists.
277
278        Args:
279            dialect: The target dialect. If this is a string, it can be optionally followed by
280                additional key-value pairs that are separated by commas and are used to specify
281                dialect settings, such as whether the dialect's identifiers are case-sensitive.
282
283        Example:
284            >>> dialect = dialect_class = get_or_raise("duckdb")
285            >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
286
287        Returns:
288            The corresponding Dialect instance.
289        """
290
291        if not dialect:
292            return cls()
293        if isinstance(dialect, _Dialect):
294            return dialect()
295        if isinstance(dialect, Dialect):
296            return dialect
297        if isinstance(dialect, str):
298            try:
299                dialect_name, *kv_pairs = dialect.split(",")
300                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
301            except ValueError:
302                raise ValueError(
303                    f"Invalid dialect format: '{dialect}'. "
304                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
305                )
306
307            result = cls.get(dialect_name.strip())
308            if not result:
309                raise ValueError(f"Unknown dialect '{dialect_name}'.")
310
311            return result(**kwargs)
312
313        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
314
315    @classmethod
316    def format_time(
317        cls, expression: t.Optional[str | exp.Expression]
318    ) -> t.Optional[exp.Expression]:
319        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
320        if isinstance(expression, str):
321            return exp.Literal.string(
322                # the time formats are quoted
323                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
324            )
325
326        if expression and expression.is_string:
327            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
328
329        return expression
330
331    def __init__(self, **kwargs) -> None:
332        normalization_strategy = kwargs.get("normalization_strategy")
333
334        if normalization_strategy is None:
335            self.normalization_strategy = self.NORMALIZATION_STRATEGY
336        else:
337            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
338
339    def __eq__(self, other: t.Any) -> bool:
340        # Does not currently take dialect state into account
341        return type(self) == other
342
343    def __hash__(self) -> int:
344        # Does not currently take dialect state into account
345        return hash(type(self))
346
347    def normalize_identifier(self, expression: E) -> E:
348        """
349        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
350
351        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
352        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
353        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
354        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
355
356        There are also dialects like Spark, which are case-insensitive even when quotes are
357        present, and dialects like MySQL, whose resolution rules match those employed by the
358        underlying operating system, for example they may always be case-sensitive in Linux.
359
360        Finally, the normalization behavior of some engines can even be controlled through flags,
361        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
362
363        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
364        that it can analyze queries in the optimizer and successfully capture their semantics.
365        """
366        if (
367            isinstance(expression, exp.Identifier)
368            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
369            and (
370                not expression.quoted
371                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
372            )
373        ):
374            expression.set(
375                "this",
376                expression.this.upper()
377                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
378                else expression.this.lower(),
379            )
380
381        return expression
382
383    def case_sensitive(self, text: str) -> bool:
384        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
385        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
386            return False
387
388        unsafe = (
389            str.islower
390            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
391            else str.isupper
392        )
393        return any(unsafe(char) for char in text)
394
395    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
396        """Checks if text can be identified given an identify option.
397
398        Args:
399            text: The text to check.
400            identify:
401                `"always"` or `True`: Always returns `True`.
402                `"safe"`: Only returns `True` if the identifier is case-insensitive.
403
404        Returns:
405            Whether or not the given text can be identified.
406        """
407        if identify is True or identify == "always":
408            return True
409
410        if identify == "safe":
411            return not self.case_sensitive(text)
412
413        return False
414
415    def quote_identifier(self, expression: E, identify: bool = True) -> E:
416        """
417        Adds quotes to a given identifier.
418
419        Args:
420            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
421            identify: If set to `False`, the quotes will only be added if the identifier is deemed
422                "unsafe", with respect to its characters and this dialect's normalization strategy.
423        """
424        if isinstance(expression, exp.Identifier):
425            name = expression.this
426            expression.set(
427                "quoted",
428                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
429            )
430
431        return expression
432
433    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
434        return self.parser(**opts).parse(self.tokenize(sql), sql)
435
436    def parse_into(
437        self, expression_type: exp.IntoType, sql: str, **opts
438    ) -> t.List[t.Optional[exp.Expression]]:
439        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
440
441    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
442        return self.generator(**opts).generate(expression, copy=copy)
443
444    def transpile(self, sql: str, **opts) -> t.List[str]:
445        return [
446            self.generate(expression, copy=False, **opts) if expression else ""
447            for expression in self.parse(sql)
448        ]
449
450    def tokenize(self, sql: str) -> t.List[Token]:
451        return self.tokenizer.tokenize(sql)
452
453    @property
454    def tokenizer(self) -> Tokenizer:
455        if not hasattr(self, "_tokenizer"):
456            self._tokenizer = self.tokenizer_class(dialect=self)
457        return self._tokenizer
458
459    def parser(self, **opts) -> Parser:
460        return self.parser_class(dialect=self, **opts)
461
462    def generator(self, **opts) -> Generator:
463        return self.generator_class(dialect=self, **opts)
Dialect(**kwargs)
331    def __init__(self, **kwargs) -> None:
332        normalization_strategy = kwargs.get("normalization_strategy")
333
334        if normalization_strategy is None:
335            self.normalization_strategy = self.NORMALIZATION_STRATEGY
336        else:
337            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
INDEX_OFFSET = 0

Determines the base index offset for arrays.

WEEK_OFFSET = 0

Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.

UNNEST_COLUMN_ONLY = False

Determines whether or not UNNEST table aliases are treated as column aliases.

ALIAS_POST_TABLESAMPLE = False

Determines whether or not the table alias comes after tablesample.

NORMALIZATION_STRATEGY = <NormalizationStrategy.LOWERCASE: 'LOWERCASE'>

Specifies the strategy according to which identifiers should be normalized.

IDENTIFIERS_CAN_START_WITH_DIGIT = False

Determines whether or not an unquoted identifier can start with a digit.

DPIPE_IS_STRING_CONCAT = True

Determines whether or not the DPIPE token (||) is a string concatenation operator.

STRICT_STRING_CONCAT = False

Determines whether or not CONCAT's arguments must be strings.

SUPPORTS_USER_DEFINED_TYPES = True

Determines whether or not user-defined data types are supported.

SUPPORTS_SEMI_ANTI_JOIN = True

Determines whether or not SEMI or ANTI joins are supported.

NORMALIZE_FUNCTIONS: bool | str = 'upper'

Determines how function names are going to be normalized.

LOG_BASE_FIRST = True

Determines whether the base comes first in the LOG function.

NULL_ORDERING = 'nulls_are_small'

Indicates the default NULL ordering method to use if not explicitly set. Possible values: "nulls_are_small", "nulls_are_large", "nulls_are_last"

TYPED_DIVISION = False

Whether the behavior of a / b depends on the types of a and b. False means a / b is always float division. True means a / b is integer division if both a and b are integers.

SAFE_DIVISION = False

Determines whether division by zero throws an error (False) or returns NULL (True).

CONCAT_COALESCE = False

A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string.

DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: Dict[str, str] = {}

Associates this dialect's time formats with their equivalent Python strftime format.

FORMAT_MAPPING: Dict[str, str] = {}

Helper mapping used to parse the special syntax CAST(x AS DATE FORMAT 'yyyy'). If empty, the corresponding trie is constructed from TIME_MAPPING.

ESCAPE_SEQUENCES: Dict[str, str] = {}

Mapping of an unescaped escape sequence to the corresponding character.

PSEUDOCOLUMNS: Set[str] = set()

Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from SELECT * queries.

PREFER_CTE_ALIAS_COLUMN = False

Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.

For example,

WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
tokenizer_class = <class 'sqlglot.tokens.Tokenizer'>
parser_class = <class 'sqlglot.parser.Parser'>
generator_class = <class 'sqlglot.generator.Generator'>
TIME_TRIE: Dict = {}
FORMAT_TRIE: Dict = {}
INVERSE_TIME_MAPPING: Dict[str, str] = {}
INVERSE_TIME_TRIE: Dict = {}
INVERSE_ESCAPE_SEQUENCES: Dict[str, str] = {}
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
BIT_START: Optional[str] = None
BIT_END: Optional[str] = None
HEX_START: Optional[str] = None
HEX_END: Optional[str] = None
BYTE_START: Optional[str] = None
BYTE_END: Optional[str] = None
UNICODE_START: Optional[str] = None
UNICODE_END: Optional[str] = None
@classmethod
def get_or_raise( cls, dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> Dialect:
273    @classmethod
274    def get_or_raise(cls, dialect: DialectType) -> Dialect:
275        """
276        Look up a dialect in the global dialect registry and return it if it exists.
277
278        Args:
279            dialect: The target dialect. If this is a string, it can be optionally followed by
280                additional key-value pairs that are separated by commas and are used to specify
281                dialect settings, such as whether the dialect's identifiers are case-sensitive.
282
283        Example:
284            >>> dialect = Dialect.get_or_raise("duckdb")
285            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
286
287        Returns:
288            The corresponding Dialect instance.
289        """
290
291        if not dialect:
292            return cls()
293        if isinstance(dialect, _Dialect):
294            return dialect()
295        if isinstance(dialect, Dialect):
296            return dialect
297        if isinstance(dialect, str):
298            try:
299                dialect_name, *kv_pairs = dialect.split(",")
300                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
301            except ValueError:
302                raise ValueError(
303                    f"Invalid dialect format: '{dialect}'. "
304                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
305                )
306
307            result = cls.get(dialect_name.strip())
308            if not result:
309                raise ValueError(f"Unknown dialect '{dialect_name}'.")
310
311            return result(**kwargs)
312
313        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

Look up a dialect in the global dialect registry and return it if it exists.

Arguments:
  • dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = Dialect.get_or_raise("duckdb")
>>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:

The corresponding Dialect instance.

@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
315    @classmethod
316    def format_time(
317        cls, expression: t.Optional[str | exp.Expression]
318    ) -> t.Optional[exp.Expression]:
319        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
320        if isinstance(expression, str):
321            return exp.Literal.string(
322                # the time formats are quoted
323                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
324            )
325
326        if expression and expression.is_string:
327            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
328
329        return expression

Converts a time format in this dialect to its equivalent Python strftime format.
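
A minimal sketch, assuming the Hive dialect's TIME_MAPPING covers these tokens:

    from sqlglot.dialects.hive import Hive

    # Time formats are parsed from string literals, so the argument keeps
    # its surrounding quotes.
    print(Hive.format_time("'yyyy-MM-dd'"))  # '%Y-%m-%d'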

def normalize_identifier(self, expression: ~E) -> ~E:
347    def normalize_identifier(self, expression: E) -> E:
348        """
349        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
350
351        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
352        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
353        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
354        and so any normalization would be prohibited in order to avoid "breaking" the identifier.
355
356        There are also dialects like Spark, which are case-insensitive even when quotes are
357        present, and dialects like MySQL, whose resolution rules match those employed by the
358        underlying operating system; for example, they may always be case-sensitive on Linux.
359
360        Finally, the normalization behavior of some engines can even be controlled through flags,
361        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
362
363        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
364        that it can analyze queries in the optimizer and successfully capture their semantics.
365        """
366        if (
367            isinstance(expression, exp.Identifier)
368            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
369            and (
370                not expression.quoted
371                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
372            )
373        ):
374            expression.set(
375                "this",
376                expression.this.upper()
377                if self.normalization_strategy is NormalizationStrategy.UPPERCASE
378                else expression.this.lower(),
379            )
380
381        return expression

Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

For example, an identifier like FoO would be resolved as foo in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.

There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system; for example, they may always be case-sensitive on Linux.

Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
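
For example, a sketch relying on Postgres defaulting to the LOWERCASE strategy and Snowflake to UPPERCASE:

    from sqlglot import exp
    from sqlglot.dialects.postgres import Postgres
    from sqlglot.dialects.snowflake import Snowflake

    ident = exp.to_identifier("FoO")  # unquoted, so it is subject to normalization
    print(Postgres().normalize_identifier(ident.copy()).name)   # foo
    print(Snowflake().normalize_identifier(ident.copy()).name)  # FOO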

def case_sensitive(self, text: str) -> bool:
383    def case_sensitive(self, text: str) -> bool:
384        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
385        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
386            return False
387
388        unsafe = (
389            str.islower
390            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
391            else str.isupper
392        )
393        return any(unsafe(char) for char in text)

Checks if text contains any case-sensitive characters, based on the dialect's rules.
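
For instance, under the default LOWERCASE strategy only uppercase characters make a string case-sensitive (a sketch):

    from sqlglot.dialects.dialect import Dialect

    d = Dialect()
    print(d.case_sensitive("foo"))  # False: lowercasing would change nothing
    print(d.case_sensitive("Foo"))  # True: the "F" would be lost by normalization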

def can_identify(self, text: str, identify: str | bool = 'safe') -> bool:
395    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
396        """Checks if text can be identified given an identify option.
397
398        Args:
399            text: The text to check.
400            identify:
401                `"always"` or `True`: Always returns `True`.
402                `"safe"`: Only returns `True` if the identifier is case-insensitive.
403
404        Returns:
405            Whether or not the given text can be identified.
406        """
407        if identify is True or identify == "always":
408            return True
409
410        if identify == "safe":
411            return not self.case_sensitive(text)
412
413        return False

Checks if text can be identified given an identify option.

Arguments:
  • text: The text to check.
  • identify: "always" or True: Always returns True. "safe": Only returns True if the identifier is case-insensitive.
Returns:

Whether or not the given text can be identified.
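
A sketch of the three outcomes under the default dialect:

    from sqlglot.dialects.dialect import Dialect

    d = Dialect()
    print(d.can_identify("foo"))            # True: "safe" and case-insensitive
    print(d.can_identify("Foo"))            # False: case-sensitive under "safe"
    print(d.can_identify("Foo", "always"))  # True: quoting is unconditional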

def quote_identifier(self, expression: ~E, identify: bool = True) -> ~E:
415    def quote_identifier(self, expression: E, identify: bool = True) -> E:
416        """
417        Adds quotes to a given identifier.
418
419        Args:
420            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
421            identify: If set to `False`, the quotes will only be added if the identifier is deemed
422                "unsafe", with respect to its characters and this dialect's normalization strategy.
423        """
424        if isinstance(expression, exp.Identifier):
425            name = expression.this
426            expression.set(
427                "quoted",
428                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
429            )
430
431        return expression

Adds quotes to a given identifier.

Arguments:
  • expression: The expression of interest. If it's not an Identifier, this method is a no-op.
  • identify: If set to False, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
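
A sketch using the default dialect's double-quoted identifiers:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("foo")
    print(Dialect().quote_identifier(ident).sql())                  # "foo"
    print(Dialect().quote_identifier(ident, identify=False).sql())  # foo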
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
433    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
434        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
436    def parse_into(
437        self, expression_type: exp.IntoType, sql: str, **opts
438    ) -> t.List[t.Optional[exp.Expression]]:
439        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: sqlglot.expressions.Expression, copy: bool = True, **opts) -> str:
441    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
442        return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> List[str]:
444    def transpile(self, sql: str, **opts) -> t.List[str]:
445        return [
446            self.generate(expression, copy=False, **opts) if expression else ""
447            for expression in self.parse(sql)
448        ]
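
Taken together, these wrappers make round-tripping SQL through a dialect straightforward (a sketch; importing sqlglot registers the bundled dialects):

    import sqlglot
    from sqlglot.dialects.dialect import Dialect

    d = Dialect.get_or_raise("duckdb")
    ast = d.parse("SELECT 1 AS x")[0]       # tokenize + parse
    print(d.generate(ast))                  # SELECT 1 AS x
    print(d.transpile("SELECT 1 AS x")[0])  # parse + generate in one step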
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
450    def tokenize(self, sql: str) -> t.List[Token]:
451        return self.tokenizer.tokenize(sql)
tokenizer: sqlglot.tokens.Tokenizer
453    @property
454    def tokenizer(self) -> Tokenizer:
455        if not hasattr(self, "_tokenizer"):
456            self._tokenizer = self.tokenizer_class(dialect=self)
457        return self._tokenizer
def parser(self, **opts) -> sqlglot.parser.Parser:
459    def parser(self, **opts) -> Parser:
460        return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> sqlglot.generator.Generator:
462    def generator(self, **opts) -> Generator:
463        return self.generator_class(dialect=self, **opts)
DialectType = typing.Union[str, Dialect, typing.Type[Dialect], NoneType]
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
469def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
470    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
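
A common pattern is to register this helper in a Generator's TRANSFORMS; a sketch with a hypothetical MyDialect:

    from sqlglot import exp, generator
    from sqlglot.dialects.dialect import Dialect, rename_func

    class MyDialect(Dialect):  # hypothetical, for illustration only
        class Generator(generator.Generator):
            TRANSFORMS = {
                **generator.Generator.TRANSFORMS,
                # Render exp.ApproxDistinct as APPROX_COUNT_DISTINCT(...).
                exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            }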
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
473def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
474    if expression.args.get("accuracy"):
475        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
476    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( name: str = 'IF', false_value: Union[str, sqlglot.expressions.Expression, NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.If], str]:
479def if_sql(
480    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
481) -> t.Callable[[Generator, exp.If], str]:
482    def _if_sql(self: Generator, expression: exp.If) -> str:
483        return self.func(
484            name,
485            expression.this,
486            expression.args.get("true"),
487            expression.args.get("false") or false_value,
488        )
489
490    return _if_sql
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
493def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
494    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
497def arrow_json_extract_scalar_sql(
498    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
499) -> str:
500    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
503def inline_array_sql(self: Generator, expression: exp.Array) -> str:
504    return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
507def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
508    return self.like_sql(
509        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
510    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
513def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
514    zone = self.sql(expression, "this")
515    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
518def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
519    if expression.args.get("recursive"):
520        self.unsupported("Recursive CTEs are unsupported")
521        expression.args["recursive"] = False
522    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
525def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
526    n = self.sql(expression, "this")
527    d = self.sql(expression, "expression")
528    return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
531def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
532    self.unsupported("TABLESAMPLE unsupported")
533    return self.sql(expression.this)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
536def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
537    self.unsupported("PIVOT unsupported")
538    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
541def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
542    return self.cast_sql(expression)
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
545def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
546    self.unsupported("Properties unsupported")
547    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
550def no_comment_column_constraint_sql(
551    self: Generator, expression: exp.CommentColumnConstraint
552) -> str:
553    self.unsupported("CommentColumnConstraint unsupported")
554    return ""
def no_map_from_entries_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.MapFromEntries) -> str:
557def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
558    self.unsupported("MAP_FROM_ENTRIES unsupported")
559    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
562def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
563    this = self.sql(expression, "this")
564    substr = self.sql(expression, "substr")
565    position = self.sql(expression, "position")
566    if position:
567        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
568    return f"STRPOS({this}, {substr})"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
571def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
572    return (
573        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
574    )
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
577def var_map_sql(
578    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
579) -> str:
580    keys = expression.args["keys"]
581    values = expression.args["values"]
582
583    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
584        self.unsupported("Cannot convert array columns into map.")
585        return self.func(map_func_name, keys, values)
586
587    args = []
588    for key, value in zip(keys.expressions, values.expressions):
589        args.append(self.sql(key))
590        args.append(self.sql(value))
591
592    return self.func(map_func_name, *args)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[str, bool, NoneType] = None) -> Callable[[List], ~E]:
595def format_time_lambda(
596    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
597) -> t.Callable[[t.List], E]:
598    """Helper used for time expressions.
599
600    Args:
601        exp_class: the expression class to instantiate.
602        dialect: target sql dialect.
603        default: the default format; if True, the dialect's TIME_FORMAT is used.
604
605    Returns:
606        A callable that can be used to return the appropriately formatted time expression.
607    """
608
609    def _format_time(args: t.List):
610        return exp_class(
611            this=seq_get(args, 0),
612            format=Dialect[dialect].format_time(
613                seq_get(args, 1)
614                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
615            ),
616        )
617
618    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format; if True, the dialect's TIME_FORMAT is used.
Returns:

A callable that can be used to return the appropriately formatted time expression.
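
A sketch of how a dialect's Parser might wire this helper into FUNCTIONS (MyDialect is hypothetical; the metaclass registers it under "mydialect"):

    from sqlglot import exp, parser
    from sqlglot.dialects.dialect import Dialect, format_time_lambda

    class MyDialect(Dialect):  # hypothetical, for illustration only
        class Parser(parser.Parser):
            FUNCTIONS = {
                **parser.Parser.FUNCTIONS,
                # Parse TO_DATE(x[, fmt]), converting fmt to strftime tokens.
                "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "mydialect"),
            }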

def time_format( dialect: Union[str, Dialect, Type[Dialect], NoneType] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.UnixToStr | sqlglot.expressions.StrToUnix], Optional[str]]:
621def time_format(
622    dialect: DialectType = None,
623) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
624    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
625        """
626        Returns the time format for a given expression, unless it's equivalent
627        to the default time format of the dialect of interest.
628        """
629        time_format = self.format_time(expression)
630        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
631
632    return _time_format
def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
635def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
636    """
637    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
638    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
639    columns are removed from the create statement.
640    """
641    has_schema = isinstance(expression.this, exp.Schema)
642    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
643
644    if has_schema and is_partitionable:
645        prop = expression.find(exp.PartitionedByProperty)
646        if prop and prop.this and not isinstance(prop.this, exp.Schema):
647            schema = expression.this
648            columns = {v.name.upper() for v in prop.this.expressions}
649            partitions = [col for col in schema.expressions if col.name.upper() in columns]
650            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
651            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
652            expression.set("this", schema)
653
654    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
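
For example (illustrative), CREATE TABLE t (x INT, y INT) PARTITIONED BY (y) would be generated as CREATE TABLE t (x INT) PARTITIONED BY (y INT).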

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[List], ~E]:
657def parse_date_delta(
658    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
659) -> t.Callable[[t.List], E]:
660    def inner_func(args: t.List) -> E:
661        unit_based = len(args) == 3
662        this = args[2] if unit_based else seq_get(args, 0)
663        unit = args[0] if unit_based else exp.Literal.string("DAY")
664        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
665        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
666
667    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[List], Optional[~E]]:
670def parse_date_delta_with_interval(
671    expression_class: t.Type[E],
672) -> t.Callable[[t.List], t.Optional[E]]:
673    def func(args: t.List) -> t.Optional[E]:
674        if len(args) < 2:
675            return None
676
677        interval = args[1]
678
679        if not isinstance(interval, exp.Interval):
680            raise ParseError(f"INTERVAL expression expected but got '{interval}'")
681
682        expression = interval.this
683        if expression and expression.is_string:
684            expression = exp.Literal.number(expression.this)
685
686        return expression_class(
687            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
688        )
689
690    return func
def date_trunc_to_time( args: List) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
693def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
694    unit = seq_get(args, 0)
695    this = seq_get(args, 1)
696
697    if isinstance(this, exp.Cast) and this.is_type("date"):
698        return exp.DateTrunc(unit=unit, this=this)
699    return exp.TimestampTrunc(this=this, unit=unit)
def date_add_interval_sql( data_type: str, kind: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
702def date_add_interval_sql(
703    data_type: str, kind: str
704) -> t.Callable[[Generator, exp.Expression], str]:
705    def func(self: Generator, expression: exp.Expression) -> str:
706        this = self.sql(expression, "this")
707        unit = expression.args.get("unit")
708        unit = exp.var(unit.name.upper() if unit else "DAY")
709        interval = exp.Interval(this=expression.expression, unit=unit)
710        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
711
712    return func
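
An illustrative TRANSFORMS entry (not any specific dialect's mapping) that yields output such as DATE_ADD(x, INTERVAL 1 DAY):

    from sqlglot import exp
    from sqlglot.dialects.dialect import date_add_interval_sql

    TRANSFORMS = {
        exp.DateAdd: date_add_interval_sql("DATE", "ADD"),  # DATE_ADD(x, INTERVAL n unit)
        exp.DateSub: date_add_interval_sql("DATE", "SUB"),  # DATE_SUB(x, INTERVAL n unit)
    }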
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
715def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
716    return self.func(
717        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
718    )
def no_timestamp_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Timestamp) -> str:
721def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
722    if not expression.expression:
723        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
724    if expression.text("expression").lower() in TIMEZONES:
725        return self.sql(
726            exp.AtTimeZone(
727                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
728                zone=expression.expression,
729            )
730        )
731    return self.function_fallback_sql(expression)
def locate_to_strposition(args: List) -> sqlglot.expressions.Expression:
734def locate_to_strposition(args: t.List) -> exp.Expression:
735    return exp.StrPosition(
736        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
737    )
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
740def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
741    return self.func(
742        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
743    )
def left_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Left) -> str:
746def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
747    return self.sql(
748        exp.Substring(
749            this=expression.this, start=exp.Literal.number(1), length=expression.expression
750        )
751    )
def right_to_substring_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Right) -> str:
754def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
755    return self.sql(
756        exp.Substring(
757            this=expression.this,
758            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
759        )
760    )
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
763def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
764    return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
767def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
768    return self.sql(exp.cast(expression.this, "date"))
def encode_decode_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression, name: str, replace: bool = True) -> str:
772def encode_decode_sql(
773    self: Generator, expression: exp.Expression, name: str, replace: bool = True
774) -> str:
775    charset = expression.args.get("charset")
776    if charset and charset.name.lower() != "utf-8":
777        self.unsupported(f"Expected utf-8 character set, got {charset}.")
778
779    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
782def min_or_least(self: Generator, expression: exp.Min) -> str:
783    name = "LEAST" if expression.expressions else "MIN"
784    return rename_func(name)(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
787def max_or_greatest(self: Generator, expression: exp.Max) -> str:
788    name = "GREATEST" if expression.expressions else "MAX"
789    return rename_func(name)(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
792def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
793    cond = expression.this
794
795    if isinstance(expression.this, exp.Distinct):
796        cond = expression.this.expressions[0]
797        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
798
799    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
802def trim_sql(self: Generator, expression: exp.Trim) -> str:
803    target = self.sql(expression, "this")
804    trim_type = self.sql(expression, "position")
805    remove_chars = self.sql(expression, "expression")
806    collation = self.sql(expression, "collation")
807
808    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
809    if not remove_chars and not collation:
810        return self.trim_sql(expression)
811
812    trim_type = f"{trim_type} " if trim_type else ""
813    remove_chars = f"{remove_chars} " if remove_chars else ""
814    from_part = "FROM " if trim_type or remove_chars else ""
815    collation = f" COLLATE {collation}" if collation else ""
816    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def str_to_time_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Expression) -> str:
819def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
820    return self.func("STRPTIME", expression.this, self.format_time(expression))
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
823def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
824    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
825        _dialect = Dialect.get_or_raise(dialect)
826        time_format = self.format_time(expression)
827        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
828            return self.sql(
829                exp.cast(
830                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
831                    "date",
832                )
833            )
834        return self.sql(exp.cast(expression.this, "date"))
835
836    return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Concat) -> str:
839def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
840    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ConcatWs) -> str:
843def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
844    delim, *rest_args = expression.expressions
845    return self.sql(
846        reduce(
847            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
848            rest_args,
849        )
850    )
def regexp_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpExtract) -> str:
853def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
854    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
855    if bad_args:
856        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")
857
858    return self.func(
859        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
860    )
def regexp_replace_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.RegexpReplace) -> str:
863def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
864    bad_args = list(
865        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
866    )
867    if bad_args:
868        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
869
870    return self.func(
871        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
872    )
def pivot_column_names( aggregations: List[sqlglot.expressions.Expression], dialect: Union[str, Dialect, Type[Dialect], NoneType]) -> List[str]:
875def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
876    names = []
877    for agg in aggregations:
878        if isinstance(agg, exp.Alias):
879            names.append(agg.alias)
880        else:
881            """
882            This case corresponds to aggregations without aliases being used as suffixes
883            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
884            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
885            Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quotes).
886            """
887            agg_all_unquoted = agg.transform(
888                lambda node: exp.Identifier(this=node.name, quoted=False)
889                if isinstance(node, exp.Identifier)
890                else node
891            )
892            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
893
894    return names
def binary_from_function(expr_type: Type[~B]) -> Callable[[List], ~B]:
897def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
898    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
def parse_timestamp_trunc(args: List) -> sqlglot.expressions.TimestampTrunc:
902def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
903    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.AnyValue) -> str:
906def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
907    return self.func("MAX", expression.this)
def bool_xor_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Xor) -> str:
910def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
911    a = self.sql(expression.left)
912    b = self.sql(expression.right)
913    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
def json_keyvalue_comma_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONKeyValue) -> str:
917def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
918    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
def is_parse_json(expression: sqlglot.expressions.Expression) -> bool:
921def is_parse_json(expression: exp.Expression) -> bool:
922    return isinstance(expression, exp.ParseJSON) or (
923        isinstance(expression, exp.Cast) and expression.is_type("json")
924    )
def isnull_to_is_null(args: List) -> sqlglot.expressions.Expression:
927def isnull_to_is_null(args: t.List) -> exp.Expression:
928    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def generatedasidentitycolumnconstraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.GeneratedAsIdentityColumnConstraint) -> str:
931def generatedasidentitycolumnconstraint_sql(
932    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
933) -> str:
934    start = self.sql(expression, "start") or "1"
935    increment = self.sql(expression, "increment") or "1"
936    return f"IDENTITY({start}, {increment})"
def arg_max_or_min_no_count( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.ArgMax | sqlglot.expressions.ArgMin], str]:
939def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
940    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
941        if expression.args.get("count"):
942            self.unsupported(f"Only two arguments are supported in function {name}.")
943
944        return self.func(name, expression.this, expression.expression)
945
946    return _arg_max_or_min_sql
def ts_or_ds_add_cast( expression: sqlglot.expressions.TsOrDsAdd) -> sqlglot.expressions.TsOrDsAdd:
949def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
950    this = expression.this.copy()
951
952    return_type = expression.return_type
953    if return_type.is_type(exp.DataType.Type.DATE):
954        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
955        # can truncate timestamp strings, because some dialects can't cast them to DATE
956        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)
957
958    expression.this.replace(exp.cast(this, return_type))
959    return expression
def date_delta_sql( name: str, cast: bool = False) -> Callable[[sqlglot.generator.Generator, Union[sqlglot.expressions.DateAdd, sqlglot.expressions.TsOrDsAdd, sqlglot.expressions.DateDiff, sqlglot.expressions.TsOrDsDiff]], str]:
962def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
963    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
964        if cast and isinstance(expression, exp.TsOrDsAdd):
965            expression = ts_or_ds_add_cast(expression)
966
967        return self.func(
968            name,
969            exp.var(expression.text("unit").upper() or "DAY"),
970            expression.expression,
971            expression.this,
972        )
973
974    return _delta_sql
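
An illustrative TRANSFORMS mapping (not any specific dialect's) for unit-first date arithmetic, e.g. producing DATEADD(DAY, 1, x):

    from sqlglot import exp
    from sqlglot.dialects.dialect import date_delta_sql

    TRANSFORMS = {
        exp.DateAdd: date_delta_sql("DATEADD"),    # DATEADD(unit, delta, this)
        exp.DateDiff: date_delta_sql("DATEDIFF"),  # DATEDIFF(unit, expression, this)
    }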