Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot.generator import Generator
  8from sqlglot.helper import flatten, seq_get
  9from sqlglot.parser import Parser
 10from sqlglot.time import format_time
 11from sqlglot.tokens import Token, Tokenizer, TokenType
 12from sqlglot.trie import new_trie
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot._typing import E
 16
 17
# Set of dialect names whose unquoted identifiers normalize to UPPERCASE.
# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}
 21
 22
class Dialects(str, Enum):
    """Names of the supported SQL dialects.

    Each member's value is the registry key under which the corresponding
    `Dialect` subclass is stored by the `_Dialect` metaclass.
    """

    # The empty string maps to the base (default) Dialect.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"
 46
 47
 48class _Dialect(type):
 49    classes: t.Dict[str, t.Type[Dialect]] = {}
 50
 51    def __eq__(cls, other: t.Any) -> bool:
 52        if cls is other:
 53            return True
 54        if isinstance(other, str):
 55            return cls is cls.get(other)
 56        if isinstance(other, Dialect):
 57            return cls is type(other)
 58
 59        return False
 60
 61    def __hash__(cls) -> int:
 62        return hash(cls.__name__.lower())
 63
 64    @classmethod
 65    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 66        return cls.classes[key]
 67
 68    @classmethod
 69    def get(
 70        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 71    ) -> t.Optional[t.Type[Dialect]]:
 72        return cls.classes.get(key, default)
 73
 74    def __new__(cls, clsname, bases, attrs):
 75        klass = super().__new__(cls, clsname, bases, attrs)
 76        enum = Dialects.__members__.get(clsname.upper())
 77        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 78
 79        klass.time_trie = new_trie(klass.time_mapping)
 80        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
 81        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
 82
 83        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 84        klass.parser_class = getattr(klass, "Parser", Parser)
 85        klass.generator_class = getattr(klass, "Generator", Generator)
 86
 87        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
 88        klass.identifier_start, klass.identifier_end = list(
 89            klass.tokenizer_class._IDENTIFIERS.items()
 90        )[0]
 91
 92        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
 93            return next(
 94                (
 95                    (s, e)
 96                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
 97                    if t == token_type
 98                ),
 99                (None, None),
100            )
101
102        klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
103        klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
104        klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
105        klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
106
107        return klass
108
109
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Dialect-specific behavior is configured through the class attributes below
    and through nested Tokenizer/Parser/Generator classes, which the
    `_Dialect` metaclass picks up at class-creation time.
    """

    # Offset applied to array indices (see parser()/generator() settings).
    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    # Passed to the generator; "upper" presumably normalizes unquoted
    # function names to uppercase — confirm against Generator docs.
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    # Default quoted literal formats for dates/times.
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    # Time-format token translation table, consumed by format_time() below.
    time_mapping: t.Dict[str, str] = {}

    # autofilled by the _Dialect metaclass — do not set these manually
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare equal to their class; _Dialect.__eq__ then also
        # makes them comparable to the dialect's registry name.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` (name, class, instance, or falsy) to a Dialect class.

        Raises:
            ValueError: if `dialect` is a string not present in the registry.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a time-format string (or string expression) through this
        dialect's `time_mapping`, returning a string literal expression.

        Non-string expressions (and None) are returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                format_time(
                    expression[1:-1],  # the time formats are quoted
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        if expression and expression.is_string:
            return exp.Literal.string(
                format_time(
                    expression.this,
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql` specifically into `expression_type` nodes."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render an expression tree back into SQL for this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and re-generate it in this dialect, one string per statement."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache a single tokenizer per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()  # type: ignore
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a Parser seeded with this dialect's settings; `opts` override them."""
        return self.parser_class(  # type: ignore
            **{
                "index_offset": self.index_offset,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "null_ordering": self.null_ordering,
                **opts,
            },
        )

    def generator(self, **opts) -> Generator:
        """Build a Generator seeded with this dialect's settings; `opts` override them."""
        return self.generator_class(  # type: ignore
            **{
                "quote_start": self.quote_start,
                "quote_end": self.quote_end,
                "bit_start": self.bit_start,
                "bit_end": self.bit_end,
                "hex_start": self.hex_start,
                "hex_end": self.hex_end,
                "byte_start": self.byte_start,
                "byte_end": self.byte_end,
                "raw_start": self.raw_start,
                "raw_end": self.raw_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                "index_offset": self.index_offset,
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "normalize_functions": self.normalize_functions,
                "null_ordering": self.null_ordering,
                **opts,
            }
        )
239
240
241DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
242
243
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator transform that renders the expression as a call to
    `name`, passing all of the expression's flattened args through."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
246
247
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning that any `accuracy` arg is dropped."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    return self.func("APPROX_COUNT_DISTINCT", expression.this)
252
253
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as an IF(condition, true, false) call."""
    branches = (expression.args.get("true"), expression.args.get("false"))
    return self.func("IF", expression.this, *branches)
258
259
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the arrow (`->`) operator."""
    operator = "->"
    return self.binary(expression, operator)
262
263
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the double-arrow (`->>`) operator."""
    operator = "->>"
    return self.binary(expression, operator)
268
269
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal using bracket syntax: [e1, e2, ...]."""
    elements = self.expressions(expression)
    return f"[{elements}]"
272
273
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE by lowercasing the subject and emitting a plain LIKE."""
    lowered_subject = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered_subject, expression=expression.args["expression"])
    return self.like_sql(like)
281
282
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional AT TIME ZONE."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
286
287
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Drop the RECURSIVE flag (with a warning) for dialects lacking recursive CTEs."""
    args = expression.args
    if args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        args["recursive"] = False
    return self.with_sql(expression)
293
294
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE: yield NULL when the denominator is zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
299
300
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Warn and render only the sampled table, dropping the TABLESAMPLE clause."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
304
305
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Warn and emit nothing for dialects without PIVOT support."""
    self.unsupported("PIVOT unsupported")
    return ""
309
310
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
313
314
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Warn and emit nothing for dialects without property clauses."""
    self.unsupported("Properties unsupported")
    return ""
318
319
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Warn and emit nothing for dialects without COMMENT column constraints."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
325
326
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating a start `position` via SUBSTR."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"
    # Search the suffix starting at `position`, then re-offset the result.
    return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
334
335
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `struct.key` with the key as a quoted identifier."""
    struct_sql = self.sql(expression, "this")
    key_sql = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{struct_sql}.{key_sql}"
340
341
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as MAP(k1, v1, k2, v2, ...) by interleaving key/value arrays."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        # Column-valued (non-literal) arrays cannot be interleaved here.
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
357
358
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        target = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            # `default is True` means "use the dialect's time format".
            fmt = target.time_format if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=target.format_time(fmt))

    return _format_time
383
384
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if is_partitionable and isinstance(expression.this, exp.Schema):
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {v.name.upper() for v in prop.this.expressions}
            partitions = [
                col for col in schema.expressions if col.name.upper() in partition_names
            ]
            # Partitioned columns move out of the main schema into the property.
            remaining = [e for e in schema.expressions if e not in partitions]
            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
406
407
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions, normalizing the unit argument."""

    def inner_func(args: t.List) -> E:
        # Three args means the unit comes first, e.g. DATEADD(unit, n, date).
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
419
420
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic expressed with INTERVAL literals."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        amount = interval.this
        # String amounts inside the interval become numeric literals.
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0],
            expression=amount,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return func
440
441
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC args: DateTrunc for DATE casts, TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    truncates_date = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    klass = exp.DateTrunc if truncates_date else exp.TimestampTrunc
    return klass(this=this, unit=unit)
449
450
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', this), defaulting to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
455
456
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, this[, position]) arguments onto a StrPosition node."""
    substr, this, position = (seq_get(args, i) for i in range(3))
    return exp.StrPosition(this=this, substr=substr, position=position)
463
464
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, this[, position])."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
469
470
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as a CAST to TIMESTAMP."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS TIMESTAMP)"
473
474
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as a CAST to DATE."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS DATE)"
477
478
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN, or LEAST when extra operands make it a scalar comparison."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
482
483
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX, or GREATEST when extra operands make it a scalar comparison."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
487
488
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT cannot survive the SUM rewrite; warn and unwrap it.
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
497
498
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with optional position, characters and collation.

    Falls back to the generator's default TRIM/LTRIM/RTRIM rendering when no
    dialect-specific pieces (characters to remove, collation) are present.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not remove_chars and not collation:
        return self.trim_sql(expression)

    parts = []
    if trim_type:
        parts.append(f"{trim_type} ")
    if remove_chars:
        parts.append(f"{remove_chars} ")
    if parts:
        # FROM is only needed when a position or char set precedes the target.
        parts.append("FROM ")
    parts.append(target)
    if collation:
        parts.append(f" COLLATE {collation}")
    return f"TRIM({''.join(parts)})"
514
515
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion via STRPTIME with the dialect's format."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
518
519
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a TsOrDsToDate renderer for `dialect` that casts through STRPTIME
    when the expression carries a non-default time format."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        default_formats = (target.time_format, target.date_format)
        if time_format and time_format not in default_formats:
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
529
530
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column names produced by a PIVOT's aggregations."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as suffixes (e.g. col_avg(foo)).
        # Unquote identifiers, because they will be re-quoted by the base
        # parser's `_parse_pivot` (via `to_identifier`); otherwise we would end
        # up with doubly quoted names like col_avg(`foo`).
        unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
class Dialects(builtins.str, enum.Enum):
24class Dialects(str, Enum):
25    DIALECT = ""
26
27    BIGQUERY = "bigquery"
28    CLICKHOUSE = "clickhouse"
29    DUCKDB = "duckdb"
30    HIVE = "hive"
31    MYSQL = "mysql"
32    ORACLE = "oracle"
33    POSTGRES = "postgres"
34    PRESTO = "presto"
35    REDSHIFT = "redshift"
36    SNOWFLAKE = "snowflake"
37    SPARK = "spark"
38    SPARK2 = "spark2"
39    SQLITE = "sqlite"
40    STARROCKS = "starrocks"
41    TABLEAU = "tableau"
42    TRINO = "trino"
43    TSQL = "tsql"
44    DATABRICKS = "databricks"
45    DRILL = "drill"
46    TERADATA = "teradata"

Enumeration of the SQL dialect names supported by sqlglot; each member's value is the key used to look up the corresponding Dialect class.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
TERADATA = <Dialects.TERADATA: 'teradata'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
111class Dialect(metaclass=_Dialect):
112    index_offset = 0
113    unnest_column_only = False
114    alias_post_tablesample = False
115    normalize_functions: t.Optional[str] = "upper"
116    null_ordering = "nulls_are_small"
117
118    date_format = "'%Y-%m-%d'"
119    dateint_format = "'%Y%m%d'"
120    time_format = "'%Y-%m-%d %H:%M:%S'"
121    time_mapping: t.Dict[str, str] = {}
122
123    # autofilled
124    quote_start = None
125    quote_end = None
126    identifier_start = None
127    identifier_end = None
128
129    time_trie = None
130    inverse_time_mapping = None
131    inverse_time_trie = None
132    tokenizer_class = None
133    parser_class = None
134    generator_class = None
135
136    def __eq__(self, other: t.Any) -> bool:
137        return type(self) == other
138
139    def __hash__(self) -> int:
140        return hash(type(self))
141
142    @classmethod
143    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
144        if not dialect:
145            return cls
146        if isinstance(dialect, _Dialect):
147            return dialect
148        if isinstance(dialect, Dialect):
149            return dialect.__class__
150
151        result = cls.get(dialect)
152        if not result:
153            raise ValueError(f"Unknown dialect '{dialect}'")
154
155        return result
156
157    @classmethod
158    def format_time(
159        cls, expression: t.Optional[str | exp.Expression]
160    ) -> t.Optional[exp.Expression]:
161        if isinstance(expression, str):
162            return exp.Literal.string(
163                format_time(
164                    expression[1:-1],  # the time formats are quoted
165                    cls.time_mapping,
166                    cls.time_trie,
167                )
168            )
169        if expression and expression.is_string:
170            return exp.Literal.string(
171                format_time(
172                    expression.this,
173                    cls.time_mapping,
174                    cls.time_trie,
175                )
176            )
177        return expression
178
179    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
180        return self.parser(**opts).parse(self.tokenize(sql), sql)
181
182    def parse_into(
183        self, expression_type: exp.IntoType, sql: str, **opts
184    ) -> t.List[t.Optional[exp.Expression]]:
185        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
186
187    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
188        return self.generator(**opts).generate(expression)
189
190    def transpile(self, sql: str, **opts) -> t.List[str]:
191        return [self.generate(expression, **opts) for expression in self.parse(sql)]
192
193    def tokenize(self, sql: str) -> t.List[Token]:
194        return self.tokenizer.tokenize(sql)
195
196    @property
197    def tokenizer(self) -> Tokenizer:
198        if not hasattr(self, "_tokenizer"):
199            self._tokenizer = self.tokenizer_class()  # type: ignore
200        return self._tokenizer
201
202    def parser(self, **opts) -> Parser:
203        return self.parser_class(  # type: ignore
204            **{
205                "index_offset": self.index_offset,
206                "unnest_column_only": self.unnest_column_only,
207                "alias_post_tablesample": self.alias_post_tablesample,
208                "null_ordering": self.null_ordering,
209                **opts,
210            },
211        )
212
213    def generator(self, **opts) -> Generator:
214        return self.generator_class(  # type: ignore
215            **{
216                "quote_start": self.quote_start,
217                "quote_end": self.quote_end,
218                "bit_start": self.bit_start,
219                "bit_end": self.bit_end,
220                "hex_start": self.hex_start,
221                "hex_end": self.hex_end,
222                "byte_start": self.byte_start,
223                "byte_end": self.byte_end,
224                "raw_start": self.raw_start,
225                "raw_end": self.raw_end,
226                "identifier_start": self.identifier_start,
227                "identifier_end": self.identifier_end,
228                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
229                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
230                "index_offset": self.index_offset,
231                "time_mapping": self.inverse_time_mapping,
232                "time_trie": self.inverse_time_trie,
233                "unnest_column_only": self.unnest_column_only,
234                "alias_post_tablesample": self.alias_post_tablesample,
235                "normalize_functions": self.normalize_functions,
236                "null_ordering": self.null_ordering,
237                **opts,
238            }
239        )
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
142    @classmethod
143    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
144        if not dialect:
145            return cls
146        if isinstance(dialect, _Dialect):
147            return dialect
148        if isinstance(dialect, Dialect):
149            return dialect.__class__
150
151        result = cls.get(dialect)
152        if not result:
153            raise ValueError(f"Unknown dialect '{dialect}'")
154
155        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
157    @classmethod
158    def format_time(
159        cls, expression: t.Optional[str | exp.Expression]
160    ) -> t.Optional[exp.Expression]:
161        if isinstance(expression, str):
162            return exp.Literal.string(
163                format_time(
164                    expression[1:-1],  # the time formats are quoted
165                    cls.time_mapping,
166                    cls.time_trie,
167                )
168            )
169        if expression and expression.is_string:
170            return exp.Literal.string(
171                format_time(
172                    expression.this,
173                    cls.time_mapping,
174                    cls.time_trie,
175                )
176            )
177        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
179    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
180        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
182    def parse_into(
183        self, expression_type: exp.IntoType, sql: str, **opts
184    ) -> t.List[t.Optional[exp.Expression]]:
185        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
187    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
188        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
190    def transpile(self, sql: str, **opts) -> t.List[str]:
191        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
193    def tokenize(self, sql: str) -> t.List[Token]:
194        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
202    def parser(self, **opts) -> Parser:
203        return self.parser_class(  # type: ignore
204            **{
205                "index_offset": self.index_offset,
206                "unnest_column_only": self.unnest_column_only,
207                "alias_post_tablesample": self.alias_post_tablesample,
208                "null_ordering": self.null_ordering,
209                **opts,
210            },
211        )
def generator(self, **opts) -> sqlglot.generator.Generator:
213    def generator(self, **opts) -> Generator:
214        return self.generator_class(  # type: ignore
215            **{
216                "quote_start": self.quote_start,
217                "quote_end": self.quote_end,
218                "bit_start": self.bit_start,
219                "bit_end": self.bit_end,
220                "hex_start": self.hex_start,
221                "hex_end": self.hex_end,
222                "byte_start": self.byte_start,
223                "byte_end": self.byte_end,
224                "raw_start": self.raw_start,
225                "raw_end": self.raw_end,
226                "identifier_start": self.identifier_start,
227                "identifier_end": self.identifier_end,
228                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
229                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
230                "index_offset": self.index_offset,
231                "time_mapping": self.inverse_time_mapping,
232                "time_trie": self.inverse_time_trie,
233                "unnest_column_only": self.unnest_column_only,
234                "alias_post_tablesample": self.alias_post_tablesample,
235                "normalize_functions": self.normalize_functions,
236                "null_ordering": self.null_ordering,
237                **opts,
238            }
239        )
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator transform that renders its expression as a call to *name*.

    All of the expression's argument values are flattened into the call's
    positional arguments.
    """

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render ApproxDistinct as APPROX_COUNT_DISTINCT, dropping `accuracy`."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as an IF(condition, true_value, false_value) call."""
    branches = (expression.args.get("true"), expression.args.get("false"))
    return self.func("IF", expression.this, *branches)
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the `->` arrow operator."""
    operator = "->"
    return self.binary(expression, operator)
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the `->>` arrow operator."""
    operator = "->>"
    return self.binary(expression, operator)
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal using square-bracket syntax."""
    items = self.expressions(expression)
    return "[" + items + "]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE for dialects without it by lowering the LIKE subject."""
    lowered_this = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered_this, expression=expression.args["expression"])
    return self.like_sql(like)
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE if given."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause, dropping the unsupported RECURSIVE modifier."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag so the base WITH renderer doesn't emit RECURSIVE.
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF guard, yielding NULL on a zero denominator."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Warn that TABLESAMPLE is unsupported and render only the sampled relation."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Warn that PIVOT is unsupported; the clause is dropped entirely."""
    self.unsupported("PIVOT unsupported")
    return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Fall back to a plain CAST for dialects that lack TRY_CAST."""
    return self.cast_sql(expression)
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Warn that table properties are unsupported; they are dropped entirely."""
    self.unsupported("Properties unsupported")
    return ""
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Warn that COMMENT column constraints are unsupported; they are dropped."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS, emulating the optional start offset."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")
    if start:
        # Search from `start`, then shift back to 1-based positions in the
        # full string.
        return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
    return f"STRPOS({haystack}, {needle})"
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as dotted, quoted-identifier syntax."""
    struct_sql = self.sql(expression, "this")
    key_sql = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{struct_sql}.{key_sql}"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as interleaved key/value call arguments.

    Falls back to passing the raw key/value columns straight through when
    they are not array literals, since they cannot be interleaved then.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        fmt = seq_get(args, 1)
        if not fmt:
            # `default is True` means "use the dialect's standard time format".
            fmt = Dialect[dialect].time_format if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:
  A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)

    if has_schema and expression.args.get("kind") in ("TABLE", "VIEW"):
        # Work on a copy so the caller's expression tree is left intact.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {col.name.upper() for col in prop.this.expressions}
            partition_cols = [col for col in schema.expressions if col.name.upper() in partition_names]
            remaining = [col for col in schema.expressions if col not in partition_cols]
            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
            expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions with an optional leading unit.

    Three-argument calls are (unit, delta, this); two-argument calls default
    the unit to DAY. Unit names may be remapped via *unit_mapping*.
    """

    def _parse(args: t.List) -> E:
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date arithmetic expressed with an INTERVAL argument."""

    def _parse(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        amount = interval.this
        if amount and amount.is_string:
            # Normalize quoted interval amounts (e.g. '5') into number literals.
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0],
            expression=amount,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return _parse
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC args: DateTrunc for DATE casts, else TimestampTrunc."""
    truncation_unit = seq_get(args, 0)
    target = seq_get(args, 1)

    is_date_cast = isinstance(target, exp.Cast) and target.is_type(exp.DataType.Type.DATE)
    if is_date_cast:
        return exp.DateTrunc(unit=truncation_unit, this=target)
    return exp.TimestampTrunc(this=target, unit=truncation_unit)
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting to day."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
def locate_to_strposition(args: t.List) -> exp.Expression:
    """Map LOCATE(substr, this[, position]) arguments onto a StrPosition node."""
    substr, this, position = (seq_get(args, i) for i in range(3))
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition using LOCATE's (substr, str[, pos]) argument order."""
    substr = expression.args.get("substr")
    start = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, start)
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as a CAST to TIMESTAMP."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS TIMESTAMP)"
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as a CAST to DATE."""
    value = self.sql(expression, "this")
    return f"CAST({value} AS DATE)"
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN, or LEAST when the call has additional operands."""
    func_name = "LEAST" if expression.expressions else "MIN"
    return rename_func(func_name)(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX, or GREATEST when the call has additional operands."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as sum(if(cond, 1, 0)); DISTINCT is unsupported."""
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", condition, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using the verbose FROM syntax only when required."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    prefix = f"{trim_type} " if trim_type else ""
    chars = f"{remove_chars} " if remove_chars else ""
    from_kw = "FROM " if prefix or chars else ""
    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{chars}{from_kw}{target}{suffix})"
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render string-to-time parsing as STRPTIME(value, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a transform that casts a TsOrDsToDate expression to DATE for *dialect*."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        known_formats = (target_dialect.time_format, target_dialect.date_format)
        if time_format and time_format not in known_formats:
            # A custom format requires parsing the string before casting.
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the column-name suffixes contributed by PIVOT aggregations."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as suffixes (e.g. col_avg(foo)).
        # Unquote identifiers here because the base parser's `_parse_pivot`
        # quotes them via `to_identifier`; otherwise we'd emit col_avg(`foo`).
        unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names