Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot.generator import Generator
  8from sqlglot.helper import flatten, seq_get
  9from sqlglot.parser import Parser
 10from sqlglot.time import format_time
 11from sqlglot.tokens import Token, Tokenizer
 12from sqlglot.trie import new_trie
 13
# Type variable for helpers below that build and return a specific
# exp.Expression subclass.
E = t.TypeVar("E", bound=exp.Expression)
 15
 16
class Dialects(str, Enum):
    """Names of all supported SQL dialects.

    The empty ``DIALECT`` member denotes the base, dialect-agnostic behavior.
    Each value is the registry key used by ``_Dialect.classes``.
    """

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"
 40
 41
class _Dialect(type):
    """Metaclass for :class:`Dialect`.

    Registers every subclass in a shared name -> class registry and autofills
    the tokenizer/parser/generator classes plus the delimiter and time-format
    attributes that ``Dialect`` declares as "autofilled".
    """

    # Registry mapping dialect name -> Dialect subclass, shared by all
    # dialect classes (it lives on the metaclass, not on each subclass).
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registered name
        # (string), and to instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Hash by lowercased class name so a class and its name-string key
        # can coexist in hash-based containers consistently with __eq__.
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Raises KeyError for unknown dialect names.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        # Non-raising lookup counterpart of __getitem__.
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the matching Dialects enum value when one exists,
        # otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute time-format lookup structures in both directions.
        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        # Nested Tokenizer/Parser/Generator classes override the defaults.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured delimiter pair is treated as canonical.
        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        # Bit/hex/byte string delimiters are optional; fall back to (None, None).
        klass.bit_start, klass.bit_end = seq_get(
            list(klass.tokenizer_class._BIT_STRINGS.items()), 0
        ) or (None, None)

        klass.hex_start, klass.hex_end = seq_get(
            list(klass.tokenizer_class._HEX_STRINGS.items()), 0
        ) or (None, None)

        klass.byte_start, klass.byte_end = seq_get(
            list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
        ) or (None, None)

        return klass
 99
100
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Bundles dialect-specific tokenizer, parser and generator configuration
    and exposes the parse/generate/transpile entry points. Subclasses are
    auto-registered by the :class:`_Dialect` metaclass.
    """

    # Parser/generator behavior knobs.
    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    # Default date/time literal formats (quoted, as they appear in SQL).
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}

    # autofilled (populated by the _Dialect metaclass in __new__)
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    def __eq__(self, other: t.Any) -> bool:
        # An instance compares equal to its class (pairs with _Dialect.__eq__).
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve *dialect* (name, instance, class or None) to a Dialect class.

        Raises:
            ValueError: if *dialect* is an unknown dialect name.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a time-format string/literal into this dialect's format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                format_time(
                    expression[1:-1],  # the time formats are quoted
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        if expression and expression.is_string:
            return exp.Literal.string(
                format_time(
                    expression.this,
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        # Non-string expressions (and None) pass through unchanged.
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse *sql* into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse *sql*, coercing the result into *expression_type*."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render an expression tree back into SQL for this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse *sql* and re-generate each statement in this dialect."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache the tokenizer per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()  # type: ignore
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a Parser seeded with this dialect's settings; *opts* win."""
        return self.parser_class(  # type: ignore
            **{
                "index_offset": self.index_offset,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "null_ordering": self.null_ordering,
                **opts,
            },
        )

    def generator(self, **opts) -> Generator:
        """Build a Generator seeded with this dialect's settings; *opts* win."""
        return self.generator_class(  # type: ignore
            **{
                "quote_start": self.quote_start,
                "quote_end": self.quote_end,
                "bit_start": self.bit_start,
                "bit_end": self.bit_end,
                "hex_start": self.hex_start,
                "hex_end": self.hex_end,
                "byte_start": self.byte_start,
                "byte_end": self.byte_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                "index_offset": self.index_offset,
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "normalize_functions": self.normalize_functions,
                "null_ordering": self.null_ordering,
                **opts,
            }
        )
228
229
# Anything that can designate a dialect: its registered name, an instance,
# a Dialect subclass, or None for the base dialect.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
231
232
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator transform rendering any expression as a call to *name*,
    passing all of the expression's flattened args as the call arguments."""

    def _rename(self: Generator, expression: exp.Expression) -> str:
        return self.func(name, *flatten(expression.args.values()))

    return _rename
235
236
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render ApproxDistinct as APPROX_COUNT_DISTINCT, warning when an
    unsupported accuracy argument is present (it is dropped)."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
241
242
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render an If expression as an IF(condition, true, false) call."""
    true_part = expression.args.get("true")
    false_part = expression.args.get("false")
    return self.func("IF", expression.this, true_part, false_part)
247
248
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON(B) extraction with the arrow operator: `lhs -> rhs`."""
    return self.binary(expression, "->")
251
252
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON(B) extraction with the double-arrow operator: `lhs ->> rhs`."""
    return self.binary(expression, "->>")
257
258
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an Array literal using inline bracket syntax: [a, b, ...]."""
    return "[" + self.expressions(expression) + "]"
261
262
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE with LIKE by lowercasing the matched value.

    NOTE: only the left-hand side is lowercased; the pattern is left as-is.
    """
    lowered = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered, expression=expression.args["expression"])
    return self.like_sql(like)
270
271
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, honoring an optional time zone."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
275
276
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Drop the RECURSIVE keyword for dialects that don't support it,
    emitting an unsupported warning. The WITH clause itself is kept."""
    recursive = expression.args.get("recursive")
    if recursive:
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)
282
283
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE: NULL when the denominator is zero."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
288
289
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE, rendering only the sampled table, with a warning."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
293
294
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop PIVOT entirely (rendered as nothing), with a warning."""
    self.unsupported("PIVOT unsupported")
    return ""
298
299
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST (no warning is emitted)."""
    return self.cast_sql(expression)
302
303
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop table properties entirely (rendered as nothing), with a warning."""
    self.unsupported("Properties unsupported")
    return ""
307
308
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop COMMENT column constraints (rendered as nothing), with a warning."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
314
315
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS, emulating the optional start position."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"
    # Search only the suffix, then shift the hit back to a full-string index.
    return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
323
324
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as dotted notation with a quoted key."""
    key = exp.Identifier(this=expression.expression, quoted=True)
    return f"{self.sql(expression, 'this')}.{self.sql(key)}"
329
330
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a Map/VarMap as MAP(k1, v1, k2, v2, ...).

    Falls back to MAP(keys, values) with a warning when the keys/values
    aren't array literals (e.g. they're columns) and can't be interleaved.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved: t.List[str] = []
    for key, value in zip(keys.expressions, values.expressions):
        interleaved.extend((self.sql(key), self.sql(value)))
    return self.func(map_func_name, *interleaved)
346
347
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.Sequence):
        fmt = seq_get(args, 1)
        if not fmt:
            # No explicit format: True means the dialect's time format,
            # otherwise use the provided default string (or None).
            if default is True:
                fmt = Dialect[dialect].time_format
            else:
                fmt = default or None
        return exp_class(this=seq_get(args, 0), format=Dialect[dialect].format_time(fmt))

    return _builder
372
373
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if isinstance(expression.this, exp.Schema) and is_partitionable:
        # Work on a copy so the caller's tree isn't mutated.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {col.name.upper() for col in prop.this.expressions}
            moved = [col for col in schema.expressions if col.name.upper() in partition_names]
            # Partition columns move out of the schema and into the property.
            schema.set("expressions", [e for e in schema.expressions if e not in moved])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=moved)))
            expression.set("this", schema)

    return self.create_sql(expression)
395
396
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser for date-delta functions.

    Supports both the 3-arg form (unit, amount, target) and the 2-arg form
    (target, amount), defaulting the unit to 'DAY'. *unit_mapping* optionally
    canonicalizes unit names (matched case-insensitively).
    """

    def _parse(args: t.Sequence) -> E:
        has_unit = len(args) == 3
        if has_unit:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _parse
408
409
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument is an
    INTERVAL (e.g. DATE_ADD(x, INTERVAL 5 DAY)). Returns None when the
    argument list is incomplete."""

    def _parse(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        amount = interval.this
        # Normalize string interval amounts (e.g. '5') into numeric literals.
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0],
            expression=amount,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return _parse
429
430
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value) into DateTrunc when the value is an
    explicit DATE cast, otherwise into TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    if is_date_cast:
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)
438
439
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC, defaulting the unit to 'day'."""
    unit = expression.text("unit") or "day"
    return self.func("DATE_TRUNC", exp.Literal.string(unit), expression.this)
444
445
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Map LOCATE(substr, string[, position]) arguments onto StrPosition,
    which stores the searched string first."""
    substr = seq_get(args, 0)
    haystack = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=haystack, substr=substr, position=position)
452
453
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition using LOCATE's substr-first argument order."""
    substr = expression.args.get("substr")
    position = expression.args.get("position")
    return self.func("LOCATE", substr, expression.this, position)
458
459
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a CAST to TIMESTAMP."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS TIMESTAMP)"
462
463
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a CAST to DATE."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS DATE)"
466
467
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN normally, but use variadic LEAST when extra operands exist."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
471
472
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX normally, but use variadic GREATEST when extra operands exist."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
476
477
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0))."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT has no SUM equivalent; warn and count its first operand.
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
486
487
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using the extended
    TRIM([position] [chars] FROM target [COLLATE ...]) form only when the
    expression carries database-specific pieces."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    pieces = [
        "TRIM(",
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
        ")",
    ]
    return "".join(pieces)
503
504
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))
507
508
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Return a TsOrDsToDate renderer bound to *dialect*'s default formats."""

    def _render(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        # Only non-default formats require an explicit string-to-time parse.
        if not time_format or time_format in (target.time_format, target.date_format):
            return f"CAST({self.sql(expression, 'this')} AS DATE)"
        return f"CAST({str_to_time_sql(self, expression)} AS DATE)"

    return _render
518
519
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return the output column names derived from PIVOT aggregations."""
    names: t.List[str] = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Un-aliased aggregations are used as name suffixes (e.g. col_avg(foo)).
        # Identifiers must be unquoted here because the base parser's
        # `_parse_pivot` quotes them again via `to_identifier`; otherwise we'd
        # end up with `col_avg(`foo`)` (notice the double quotes).
        def _unquote(node: exp.Expression) -> exp.Expression:
            if isinstance(node, exp.Identifier):
                return exp.Identifier(this=node.name, quoted=False)
            return node

        unquoted = agg.transform(_unquote)
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
class Dialects(builtins.str, enum.Enum):
18class Dialects(str, Enum):
19    DIALECT = ""
20
21    BIGQUERY = "bigquery"
22    CLICKHOUSE = "clickhouse"
23    DUCKDB = "duckdb"
24    HIVE = "hive"
25    MYSQL = "mysql"
26    ORACLE = "oracle"
27    POSTGRES = "postgres"
28    PRESTO = "presto"
29    REDSHIFT = "redshift"
30    SNOWFLAKE = "snowflake"
31    SPARK = "spark"
32    SPARK2 = "spark2"
33    SQLITE = "sqlite"
34    STARROCKS = "starrocks"
35    TABLEAU = "tableau"
36    TRINO = "trino"
37    TSQL = "tsql"
38    DATABRICKS = "databricks"
39    DRILL = "drill"
40    TERADATA = "teradata"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
TERADATA = <Dialects.TERADATA: 'teradata'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
102class Dialect(metaclass=_Dialect):
103    index_offset = 0
104    unnest_column_only = False
105    alias_post_tablesample = False
106    normalize_functions: t.Optional[str] = "upper"
107    null_ordering = "nulls_are_small"
108
109    date_format = "'%Y-%m-%d'"
110    dateint_format = "'%Y%m%d'"
111    time_format = "'%Y-%m-%d %H:%M:%S'"
112    time_mapping: t.Dict[str, str] = {}
113
114    # autofilled
115    quote_start = None
116    quote_end = None
117    identifier_start = None
118    identifier_end = None
119
120    time_trie = None
121    inverse_time_mapping = None
122    inverse_time_trie = None
123    tokenizer_class = None
124    parser_class = None
125    generator_class = None
126
127    def __eq__(self, other: t.Any) -> bool:
128        return type(self) == other
129
130    def __hash__(self) -> int:
131        return hash(type(self))
132
133    @classmethod
134    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
135        if not dialect:
136            return cls
137        if isinstance(dialect, _Dialect):
138            return dialect
139        if isinstance(dialect, Dialect):
140            return dialect.__class__
141
142        result = cls.get(dialect)
143        if not result:
144            raise ValueError(f"Unknown dialect '{dialect}'")
145
146        return result
147
148    @classmethod
149    def format_time(
150        cls, expression: t.Optional[str | exp.Expression]
151    ) -> t.Optional[exp.Expression]:
152        if isinstance(expression, str):
153            return exp.Literal.string(
154                format_time(
155                    expression[1:-1],  # the time formats are quoted
156                    cls.time_mapping,
157                    cls.time_trie,
158                )
159            )
160        if expression and expression.is_string:
161            return exp.Literal.string(
162                format_time(
163                    expression.this,
164                    cls.time_mapping,
165                    cls.time_trie,
166                )
167            )
168        return expression
169
170    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
171        return self.parser(**opts).parse(self.tokenize(sql), sql)
172
173    def parse_into(
174        self, expression_type: exp.IntoType, sql: str, **opts
175    ) -> t.List[t.Optional[exp.Expression]]:
176        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
177
178    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
179        return self.generator(**opts).generate(expression)
180
181    def transpile(self, sql: str, **opts) -> t.List[str]:
182        return [self.generate(expression, **opts) for expression in self.parse(sql)]
183
184    def tokenize(self, sql: str) -> t.List[Token]:
185        return self.tokenizer.tokenize(sql)
186
187    @property
188    def tokenizer(self) -> Tokenizer:
189        if not hasattr(self, "_tokenizer"):
190            self._tokenizer = self.tokenizer_class()  # type: ignore
191        return self._tokenizer
192
193    def parser(self, **opts) -> Parser:
194        return self.parser_class(  # type: ignore
195            **{
196                "index_offset": self.index_offset,
197                "unnest_column_only": self.unnest_column_only,
198                "alias_post_tablesample": self.alias_post_tablesample,
199                "null_ordering": self.null_ordering,
200                **opts,
201            },
202        )
203
204    def generator(self, **opts) -> Generator:
205        return self.generator_class(  # type: ignore
206            **{
207                "quote_start": self.quote_start,
208                "quote_end": self.quote_end,
209                "bit_start": self.bit_start,
210                "bit_end": self.bit_end,
211                "hex_start": self.hex_start,
212                "hex_end": self.hex_end,
213                "byte_start": self.byte_start,
214                "byte_end": self.byte_end,
215                "identifier_start": self.identifier_start,
216                "identifier_end": self.identifier_end,
217                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
218                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
219                "index_offset": self.index_offset,
220                "time_mapping": self.inverse_time_mapping,
221                "time_trie": self.inverse_time_trie,
222                "unnest_column_only": self.unnest_column_only,
223                "alias_post_tablesample": self.alias_post_tablesample,
224                "normalize_functions": self.normalize_functions,
225                "null_ordering": self.null_ordering,
226                **opts,
227            }
228        )
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
133    @classmethod
134    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
135        if not dialect:
136            return cls
137        if isinstance(dialect, _Dialect):
138            return dialect
139        if isinstance(dialect, Dialect):
140            return dialect.__class__
141
142        result = cls.get(dialect)
143        if not result:
144            raise ValueError(f"Unknown dialect '{dialect}'")
145
146        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
148    @classmethod
149    def format_time(
150        cls, expression: t.Optional[str | exp.Expression]
151    ) -> t.Optional[exp.Expression]:
152        if isinstance(expression, str):
153            return exp.Literal.string(
154                format_time(
155                    expression[1:-1],  # the time formats are quoted
156                    cls.time_mapping,
157                    cls.time_trie,
158                )
159            )
160        if expression and expression.is_string:
161            return exp.Literal.string(
162                format_time(
163                    expression.this,
164                    cls.time_mapping,
165                    cls.time_trie,
166                )
167            )
168        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
170    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
171        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
173    def parse_into(
174        self, expression_type: exp.IntoType, sql: str, **opts
175    ) -> t.List[t.Optional[exp.Expression]]:
176        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
178    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
179        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
181    def transpile(self, sql: str, **opts) -> t.List[str]:
182        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
184    def tokenize(self, sql: str) -> t.List[Token]:
185        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
193    def parser(self, **opts) -> Parser:
194        return self.parser_class(  # type: ignore
195            **{
196                "index_offset": self.index_offset,
197                "unnest_column_only": self.unnest_column_only,
198                "alias_post_tablesample": self.alias_post_tablesample,
199                "null_ordering": self.null_ordering,
200                **opts,
201            },
202        )
def generator(self, **opts) -> sqlglot.generator.Generator:
204    def generator(self, **opts) -> Generator:
205        return self.generator_class(  # type: ignore
206            **{
207                "quote_start": self.quote_start,
208                "quote_end": self.quote_end,
209                "bit_start": self.bit_start,
210                "bit_end": self.bit_end,
211                "hex_start": self.hex_start,
212                "hex_end": self.hex_end,
213                "byte_start": self.byte_start,
214                "byte_end": self.byte_end,
215                "identifier_start": self.identifier_start,
216                "identifier_end": self.identifier_end,
217                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
218                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
219                "index_offset": self.index_offset,
220                "time_mapping": self.inverse_time_mapping,
221                "time_trie": self.inverse_time_trie,
222                "unnest_column_only": self.unnest_column_only,
223                "alias_post_tablesample": self.alias_post_tablesample,
224                "normalize_functions": self.normalize_functions,
225                "null_ordering": self.null_ordering,
226                **opts,
227            }
228        )
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
234def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
235    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT; an `accuracy` argument cannot be expressed
    and is dropped with an "unsupported" warning."""
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as an IF(condition, true, false) function call."""
    branches = expression.args
    return self.func("IF", expression.this, branches.get("true"), branches.get("false"))
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON/JSONB extraction as a binary `->` (arrow) operator."""
    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON/JSONB extraction as a binary `->>` operator."""
    return self.binary(expression, "->>")
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal with bracket syntax: [a, b, ...]."""
    elements = self.expressions(expression)
    return f"[{elements}]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Render ILIKE as a plain LIKE over a lowercased subject, for dialects
    lacking ILIKE.

    NOTE(review): only the left-hand side is lowercased here; the pattern is
    passed through unchanged — confirm callers supply lowercase patterns.
    """
    lowered_subject = exp.Lower(this=expression.this)
    like = exp.Like(this=lowered_subject, expression=expression.args["expression"])
    return self.like_sql(like)
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE when a
    zone argument is present."""
    zone = self.sql(expression, "this")
    if zone:
        return f"CURRENT_DATE AT TIME ZONE {zone}"
    return "CURRENT_DATE"
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause, dropping the RECURSIVE flag with a warning for
    dialects that cannot express it.

    NOTE(review): the flag is cleared on the node in place (no copy) — confirm
    callers do not reuse the original tree afterwards.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE as IF(denominator <> 0, n / d, NULL)."""
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    return f"IF({denominator} <> 0, {numerator} / {denominator}, NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop an unsupported TABLESAMPLE clause, rendering only the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop an unsupported PIVOT clause entirely, emitting a warning."""
    self.unsupported("PIVOT unsupported")
    return ""
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects without TRY_CAST."""
    return self.cast_sql(expression)
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop unsupported table properties, emitting a warning."""
    self.unsupported("Properties unsupported")
    return ""
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop an unsupported column COMMENT constraint, emitting a warning."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS.

    When a start position is given it is emulated by searching a SUBSTR of the
    subject and shifting the result back by `position - 1`.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if not position:
        return f"STRPOS({this}, {substr})"
    return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as dotted notation with a quoted field name."""
    this = self.sql(expression, "this")
    field = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{field}"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a MAP(...) call with keys and values interleaved as alternating
    arguments.

    If either side is not a literal array (e.g. a column reference), the two
    operands are emitted as-is with an "unsupported" warning, since they cannot
    be interleaved.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = [
        self.sql(item)
        for pair in zip(keys.expressions, values.expressions)
        for item in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.Sequence):
        # When no explicit format argument was parsed, fall back to the
        # dialect's standard time format (default=True) or to `default` itself.
        fallback = Dialect[dialect].time_format if default is True else default or None
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(seq_get(args, 1) or fallback),
        )

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        expression = expression.copy()  # don't mutate the caller's tree
        prop = expression.find(exp.PartitionedByProperty)
        # Only rewrite when PARTITIONED BY holds bare column names, not a schema.
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            partition_names = {col.name.upper() for col in prop.this.expressions}
            moved = [col for col in schema.expressions if col.name.upper() in partition_names]
            schema.set("expressions", [col for col in schema.expressions if col not in moved])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=moved)))
            expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser callback for DATE_ADD/DATE_DIFF-style functions.

    Handles both the 2-argument form (this, expression) with an implicit DAY
    unit and the 3-argument form (unit, expression, this), optionally
    canonicalizing the unit name through `unit_mapping`.
    """

    def inner_func(args: t.Sequence) -> E:
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Build a parser callback for date arithmetic whose second argument is an
    INTERVAL node; returns None when fewer than two arguments were parsed."""

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        amount = interval.this
        # String interval amounts (e.g. INTERVAL '5' DAY) become numeric literals.
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0],
            expression=amount,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return func
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Map DATE_TRUNC(unit, this) to DateTrunc when `this` is cast to DATE,
    and to TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    is_date_cast = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    if is_date_cast:
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', this), defaulting to 'day'."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Map LOCATE(substr, this[, position]) arguments onto a StrPosition node
    (note the swapped order of the first two arguments)."""
    substr, haystack, position = (seq_get(args, i) for i in range(3))
    return exp.StrPosition(this=haystack, substr=substr, position=position)
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, this[, position])."""
    args = expression.args
    return self.func("LOCATE", args.get("substr"), expression.this, args.get("position"))
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as a CAST to TIMESTAMP."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS TIMESTAMP)"
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as a CAST to DATE."""
    inner = self.sql(expression, "this")
    return f"CAST({inner} AS DATE)"
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as LEAST when extra operands are present, else as MIN."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as GREATEST when extra operands are present, else as MAX."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(cond) as sum(if(cond, 1, 0)).

    A DISTINCT wrapper cannot be preserved by this rewrite; it is unwrapped
    with an "unsupported" warning.
    """
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with the full [BOTH|LEADING|TRAILING] chars FROM target
    COLLATE syntax.

    When neither custom characters nor a collation are present, defer to the
    generator's own plain TRIM/LTRIM/RTRIM rendering.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    parts = [
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
    ]
    return f"TRIM({''.join(parts)})"
def str_to_time_sql(self, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(this, format)."""
    fmt = self.format_time(expression)
    return self.func("STRPTIME", expression.this, fmt)
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a TsOrDsToDate renderer bound to `dialect`'s standard time and
    date formats."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        # A non-standard format means the value must be parsed first, then cast.
        if time_format and time_format not in (target_dialect.time_format, target_dialect.date_format):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column-name suffixes for a list of PIVOT aggregations."""
    names: t.List[str] = []

    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # Aggregations without aliases are used as suffixes (e.g. col_avg(foo)).
        # Unquote identifiers here because the base parser's `_parse_pivot` will
        # re-quote them via `to_identifier`; otherwise we'd end up with
        # col_avg(`foo`) (notice the double quotes).
        unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names