Edit on GitHub

sqlglot.dialects.dialect

  1from __future__ import annotations
  2
  3import typing as t
  4from enum import Enum
  5
  6from sqlglot import exp
  7from sqlglot.generator import Generator
  8from sqlglot.helper import flatten, seq_get
  9from sqlglot.parser import Parser
 10from sqlglot.time import format_time
 11from sqlglot.tokens import Token, Tokenizer
 12from sqlglot.trie import new_trie
 13
 14E = t.TypeVar("E", bound=exp.Expression)
 15
 16
 17class Dialects(str, Enum):
 18    DIALECT = ""
 19
 20    BIGQUERY = "bigquery"
 21    CLICKHOUSE = "clickhouse"
 22    DUCKDB = "duckdb"
 23    HIVE = "hive"
 24    MYSQL = "mysql"
 25    ORACLE = "oracle"
 26    POSTGRES = "postgres"
 27    PRESTO = "presto"
 28    REDSHIFT = "redshift"
 29    SNOWFLAKE = "snowflake"
 30    SPARK = "spark"
 31    SPARK2 = "spark2"
 32    SQLITE = "sqlite"
 33    STARROCKS = "starrocks"
 34    TABLEAU = "tableau"
 35    TRINO = "trino"
 36    TSQL = "tsql"
 37    DATABRICKS = "databricks"
 38    DRILL = "drill"
 39    TERADATA = "teradata"
 40
 41
 42class _Dialect(type):
 43    classes: t.Dict[str, t.Type[Dialect]] = {}
 44
 45    @classmethod
 46    def __getitem__(cls, key: str) -> t.Type[Dialect]:
 47        return cls.classes[key]
 48
 49    @classmethod
 50    def get(
 51        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
 52    ) -> t.Optional[t.Type[Dialect]]:
 53        return cls.classes.get(key, default)
 54
 55    def __new__(cls, clsname, bases, attrs):
 56        klass = super().__new__(cls, clsname, bases, attrs)
 57        enum = Dialects.__members__.get(clsname.upper())
 58        cls.classes[enum.value if enum is not None else clsname.lower()] = klass
 59
 60        klass.time_trie = new_trie(klass.time_mapping)
 61        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
 62        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
 63
 64        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
 65        klass.parser_class = getattr(klass, "Parser", Parser)
 66        klass.generator_class = getattr(klass, "Generator", Generator)
 67
 68        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
 69        klass.identifier_start, klass.identifier_end = list(
 70            klass.tokenizer_class._IDENTIFIERS.items()
 71        )[0]
 72
 73        klass.bit_start, klass.bit_end = seq_get(
 74            list(klass.tokenizer_class._BIT_STRINGS.items()), 0
 75        ) or (None, None)
 76
 77        klass.hex_start, klass.hex_end = seq_get(
 78            list(klass.tokenizer_class._HEX_STRINGS.items()), 0
 79        ) or (None, None)
 80
 81        klass.byte_start, klass.byte_end = seq_get(
 82            list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
 83        ) or (None, None)
 84
 85        return klass
 86
 87
 88class Dialect(metaclass=_Dialect):
 89    index_offset = 0
 90    unnest_column_only = False
 91    alias_post_tablesample = False
 92    normalize_functions: t.Optional[str] = "upper"
 93    null_ordering = "nulls_are_small"
 94
 95    date_format = "'%Y-%m-%d'"
 96    dateint_format = "'%Y%m%d'"
 97    time_format = "'%Y-%m-%d %H:%M:%S'"
 98    time_mapping: t.Dict[str, str] = {}
 99
100    # autofilled
101    quote_start = None
102    quote_end = None
103    identifier_start = None
104    identifier_end = None
105
106    time_trie = None
107    inverse_time_mapping = None
108    inverse_time_trie = None
109    tokenizer_class = None
110    parser_class = None
111    generator_class = None
112
113    @classmethod
114    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
115        if not dialect:
116            return cls
117        if isinstance(dialect, _Dialect):
118            return dialect
119        if isinstance(dialect, Dialect):
120            return dialect.__class__
121
122        result = cls.get(dialect)
123        if not result:
124            raise ValueError(f"Unknown dialect '{dialect}'")
125
126        return result
127
128    @classmethod
129    def format_time(
130        cls, expression: t.Optional[str | exp.Expression]
131    ) -> t.Optional[exp.Expression]:
132        if isinstance(expression, str):
133            return exp.Literal.string(
134                format_time(
135                    expression[1:-1],  # the time formats are quoted
136                    cls.time_mapping,
137                    cls.time_trie,
138                )
139            )
140        if expression and expression.is_string:
141            return exp.Literal.string(
142                format_time(
143                    expression.this,
144                    cls.time_mapping,
145                    cls.time_trie,
146                )
147            )
148        return expression
149
150    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
151        return self.parser(**opts).parse(self.tokenize(sql), sql)
152
153    def parse_into(
154        self, expression_type: exp.IntoType, sql: str, **opts
155    ) -> t.List[t.Optional[exp.Expression]]:
156        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
157
158    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
159        return self.generator(**opts).generate(expression)
160
161    def transpile(self, sql: str, **opts) -> t.List[str]:
162        return [self.generate(expression, **opts) for expression in self.parse(sql)]
163
164    def tokenize(self, sql: str) -> t.List[Token]:
165        return self.tokenizer.tokenize(sql)
166
167    @property
168    def tokenizer(self) -> Tokenizer:
169        if not hasattr(self, "_tokenizer"):
170            self._tokenizer = self.tokenizer_class()  # type: ignore
171        return self._tokenizer
172
173    def parser(self, **opts) -> Parser:
174        return self.parser_class(  # type: ignore
175            **{
176                "index_offset": self.index_offset,
177                "unnest_column_only": self.unnest_column_only,
178                "alias_post_tablesample": self.alias_post_tablesample,
179                "null_ordering": self.null_ordering,
180                **opts,
181            },
182        )
183
184    def generator(self, **opts) -> Generator:
185        return self.generator_class(  # type: ignore
186            **{
187                "quote_start": self.quote_start,
188                "quote_end": self.quote_end,
189                "bit_start": self.bit_start,
190                "bit_end": self.bit_end,
191                "hex_start": self.hex_start,
192                "hex_end": self.hex_end,
193                "byte_start": self.byte_start,
194                "byte_end": self.byte_end,
195                "identifier_start": self.identifier_start,
196                "identifier_end": self.identifier_end,
197                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
198                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
199                "index_offset": self.index_offset,
200                "time_mapping": self.inverse_time_mapping,
201                "time_trie": self.inverse_time_trie,
202                "unnest_column_only": self.unnest_column_only,
203                "alias_post_tablesample": self.alias_post_tablesample,
204                "normalize_functions": self.normalize_functions,
205                "null_ordering": self.null_ordering,
206                **opts,
207            }
208        )
209
210
211DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
212
213
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator callback that renders an expression as a call to `name`.

    The call's arguments are the flattened values of `expression.args`.
    """

    def _renamed(self, expression):
        call_args = flatten(expression.args.values())
        return self.func(name, *call_args)

    return _renamed
216
217
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render as APPROX_COUNT_DISTINCT(<this>).

    An `accuracy` argument is flagged through `self.unsupported` and left out
    of the generated call.
    """
    accuracy = expression.args.get("accuracy")
    if accuracy:
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")

    operand = expression.this
    return self.func("APPROX_COUNT_DISTINCT", operand)
222
223
def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as IF(<condition>, <true>, <false>); absent branches are passed as None."""
    branches = expression.args
    condition = expression.this
    when_true = branches.get("true")
    when_false = branches.get("false")
    return self.func("IF", condition, when_true, when_false)
228
229
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render a JSON extraction as a binary expression using the `->` operator."""
    operator = "->"
    return self.binary(expression, operator)
232
233
def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render a scalar JSON extraction as a binary expression using the `->>` operator."""
    operator = "->>"
    return self.binary(expression, operator)
238
239
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as its generated element list wrapped in square brackets."""
    elements = self.expressions(expression)
    return "[{}]".format(elements)
242
243
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE for dialects without it: LOWER(<this>) LIKE LOWER(<pattern>).

    Both operands are lower-cased. Lower-casing only the left side would make
    any pattern that contains uppercase characters impossible to match.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=exp.Lower(this=expression.args["expression"]),
        )
    )
251
252
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, appending AT TIME ZONE when a zone is given."""
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
256
257
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for dialects lacking recursive CTEs.

    A recursive WITH is flagged through `self.unsupported` and its `recursive`
    arg is reset to False in place (the passed expression is modified).
    """
    with_args = expression.args
    if with_args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        with_args["recursive"] = False

    return self.with_sql(expression)
263
264
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate a safe division: IF((<d>) <> 0, (<n>) / (<d>), NULL).

    Numerator and denominator are parenthesized because they are interpolated
    as raw SQL text. Without the parentheses a compound operand such as
    `a + b` would bind incorrectly (`x / a + b`).
    """
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
269
270
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Flag TABLESAMPLE as unsupported and render only the sampled expression."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
274
275
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Flag PIVOT as unsupported and render nothing for it."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
279
280
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TRY_CAST through the generator's regular `cast_sql`."""
    rendered = self.cast_sql(expression)
    return rendered
283
284
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Flag properties as unsupported and render nothing for them."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
288
289
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Flag a column COMMENT constraint as unsupported and render nothing for it."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
295
296
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a string-position lookup with STRPOS.

    With a start position, the search runs on SUBSTR(<this>, <position>) and
    the result is shifted back by `<position> - 1`.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # NOTE(review): if STRPOS yields 0 for "not found", this form returns
    # <position> - 1 instead of 0 — confirm whether that matters for callers.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
304
305
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `<struct>.<key>`, with the key as a quoted identifier."""
    struct_sql = self.sql(expression, "this")
    key = exp.Identifier(this=expression.expression, quoted=True)
    return f"{struct_sql}.{self.sql(key)}"
310
311
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `map_func_name(k1, v1, k2, v2, ...)`.

    When keys and values are both array literals they are interleaved pairwise.
    Otherwise the conversion is flagged through `self.unsupported` and the two
    operands are passed to the function unchanged.
    """
    keys, values = expression.args["keys"], expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        interleaved = [
            self.sql(node)
            for pair in zip(keys.expressions, values.expressions)
            for node in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
327
328
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Build an argument parser for time expressions that carry a format.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose time mapping converts the format.
        default: format to use when none is given; True selects the dialect's
            `time_format`.

    Returns:
        A callable that takes the argument sequence and returns an `exp_class`
        whose `format` has been passed through the dialect's `format_time`.
    """

    def _format_time(args: t.Sequence):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        time_format = seq_get(args, 1)
        if not time_format:
            # No explicit format: fall back to the requested default, if any.
            time_format = dialect_class.time_format if default is True else default or None

        return exp_class(this=this, format=dialect_class.format_time(time_format))

    return _format_time
353
354
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names rather than a schema, the matching column
    definitions are moved out of the create statement's schema and into the property.
    The rewrite is applied to a copy of the expression.
    """
    if isinstance(expression.this, exp.Schema) and expression.args.get("kind") in (
        "TABLE",
        "VIEW",
    ):
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)

        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names are matched case-insensitively.
            partition_names = {column.name.upper() for column in prop.this.expressions}
            partitions = [
                column for column in schema.expressions if column.name.upper() in partition_names
            ]
            remaining = [column for column in schema.expressions if column not in partitions]

            schema.set("expressions", remaining)
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
376
377
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build an argument parser for date-delta functions.

    With exactly three arguments they are read as (unit, expression, this);
    otherwise as (this, expression) with a unit of "DAY". When `unit_mapping`
    is given, the lowercased unit name is translated through it.
    """

    def inner_func(args: t.Sequence) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            this = seq_get(args, 0)
            unit = exp.Literal.string("DAY")

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
389
390
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Build an argument parser for date-delta functions whose second argument is an interval.

    The returned callable yields None when fewer than two arguments are given.
    """

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]

        amount = interval.this
        if amount and amount.is_string:
            # A string amount is re-created as a number literal.
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return func
410
411
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse (unit, this) arguments into a DateTrunc when `this` is a cast to DATE, else a TimestampTrunc."""
    unit, this = seq_get(args, 0), seq_get(args, 1)

    if not (isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)):
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
419
420
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render as DATE_TRUNC('<unit>', <this>); the unit falls back to 'day' when empty."""
    unit = exp.Literal.string(expression.text("unit") or "day")
    return self.func("DATE_TRUNC", unit, expression.this)
425
426
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Parse LOCATE-style arguments (substr, this[, position]) into a StrPosition."""
    substr = seq_get(args, 0)
    this = seq_get(args, 1)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
433
434
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render a StrPosition as LOCATE(<substr>, <this>, <position>)."""
    extra = expression.args
    substr = extra.get("substr")
    position = extra.get("position")
    return self.func("LOCATE", substr, expression.this, position)
439
440
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as CAST(<this> AS TIMESTAMP)."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS TIMESTAMP)"
443
444
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as CAST(<this> AS DATE)."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS DATE)"
447
448
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render as LEAST when the expression has extra `expressions`, otherwise as MIN."""
    if expression.expressions:
        return rename_func("LEAST")(self, expression)
    return rename_func("MIN")(self, expression)
452
453
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render as GREATEST when the expression has extra `expressions`, otherwise as MAX."""
    if expression.expressions:
        return rename_func("GREATEST")(self, expression)
    return rename_func("MAX")(self, expression)
457
458
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render COUNT_IF(<cond>) as sum(if(<cond>, 1, 0)).

    For a DISTINCT operand only its first expression is used as the condition,
    and the loss of DISTINCT is flagged through `self.unsupported`.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition, 1, 0)
    return self.func("sum", one_or_zero)
467
468
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM([<position> ][<chars> ][FROM ]<target>[ COLLATE <collation>]).

    When neither trim characters nor a collation are present, rendering is
    delegated to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    # FROM is only emitted when something precedes the target.
    prefix = "".join(f"{part} " for part in (position, chars) if part)
    from_part = "FROM " if prefix else ""
    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{from_part}{target}{suffix})"
484
485
def str_to_time_sql(self, expression: exp.Expression) -> str:
    """Render as STRPTIME(<this>, <format>), the format coming from `self.format_time`."""
    operand = expression.this
    time_format = self.format_time(expression)
    return self.func("STRPTIME", operand, time_format)
488
489
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator callback that renders TsOrDsToDate as a cast to DATE for `dialect`.

    When the expression's format is neither the dialect's `time_format` nor its
    `date_format`, the operand is first parsed with STRPTIME; otherwise the
    operand is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        if not time_format or time_format in (_dialect.time_format, _dialect.date_format):
            return f"CAST({self.sql(expression, 'this')} AS DATE)"
        return f"CAST({str_to_time_sql(self, expression)} AS DATE)"

    return _ts_or_ds_to_date_sql
class Dialects(builtins.str, enum.Enum):
18class Dialects(str, Enum):
19    DIALECT = ""
20
21    BIGQUERY = "bigquery"
22    CLICKHOUSE = "clickhouse"
23    DUCKDB = "duckdb"
24    HIVE = "hive"
25    MYSQL = "mysql"
26    ORACLE = "oracle"
27    POSTGRES = "postgres"
28    PRESTO = "presto"
29    REDSHIFT = "redshift"
30    SNOWFLAKE = "snowflake"
31    SPARK = "spark"
32    SPARK2 = "spark2"
33    SQLITE = "sqlite"
34    STARROCKS = "starrocks"
35    TABLEAU = "tableau"
36    TRINO = "trino"
37    TSQL = "tsql"
38    DATABRICKS = "databricks"
39    DRILL = "drill"
40    TERADATA = "teradata"

An enumeration.

DIALECT = <Dialects.DIALECT: ''>
BIGQUERY = <Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE = <Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB = <Dialects.DUCKDB: 'duckdb'>
HIVE = <Dialects.HIVE: 'hive'>
MYSQL = <Dialects.MYSQL: 'mysql'>
ORACLE = <Dialects.ORACLE: 'oracle'>
POSTGRES = <Dialects.POSTGRES: 'postgres'>
PRESTO = <Dialects.PRESTO: 'presto'>
REDSHIFT = <Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE = <Dialects.SNOWFLAKE: 'snowflake'>
SPARK = <Dialects.SPARK: 'spark'>
SPARK2 = <Dialects.SPARK2: 'spark2'>
SQLITE = <Dialects.SQLITE: 'sqlite'>
STARROCKS = <Dialects.STARROCKS: 'starrocks'>
TABLEAU = <Dialects.TABLEAU: 'tableau'>
TRINO = <Dialects.TRINO: 'trino'>
TSQL = <Dialects.TSQL: 'tsql'>
DATABRICKS = <Dialects.DATABRICKS: 'databricks'>
DRILL = <Dialects.DRILL: 'drill'>
TERADATA = <Dialects.TERADATA: 'teradata'>
Inherited Members
enum.Enum
name
value
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
class Dialect:
 89class Dialect(metaclass=_Dialect):
 90    index_offset = 0
 91    unnest_column_only = False
 92    alias_post_tablesample = False
 93    normalize_functions: t.Optional[str] = "upper"
 94    null_ordering = "nulls_are_small"
 95
 96    date_format = "'%Y-%m-%d'"
 97    dateint_format = "'%Y%m%d'"
 98    time_format = "'%Y-%m-%d %H:%M:%S'"
 99    time_mapping: t.Dict[str, str] = {}
100
101    # autofilled
102    quote_start = None
103    quote_end = None
104    identifier_start = None
105    identifier_end = None
106
107    time_trie = None
108    inverse_time_mapping = None
109    inverse_time_trie = None
110    tokenizer_class = None
111    parser_class = None
112    generator_class = None
113
114    @classmethod
115    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
116        if not dialect:
117            return cls
118        if isinstance(dialect, _Dialect):
119            return dialect
120        if isinstance(dialect, Dialect):
121            return dialect.__class__
122
123        result = cls.get(dialect)
124        if not result:
125            raise ValueError(f"Unknown dialect '{dialect}'")
126
127        return result
128
129    @classmethod
130    def format_time(
131        cls, expression: t.Optional[str | exp.Expression]
132    ) -> t.Optional[exp.Expression]:
133        if isinstance(expression, str):
134            return exp.Literal.string(
135                format_time(
136                    expression[1:-1],  # the time formats are quoted
137                    cls.time_mapping,
138                    cls.time_trie,
139                )
140            )
141        if expression and expression.is_string:
142            return exp.Literal.string(
143                format_time(
144                    expression.this,
145                    cls.time_mapping,
146                    cls.time_trie,
147                )
148            )
149        return expression
150
151    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
152        return self.parser(**opts).parse(self.tokenize(sql), sql)
153
154    def parse_into(
155        self, expression_type: exp.IntoType, sql: str, **opts
156    ) -> t.List[t.Optional[exp.Expression]]:
157        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
158
159    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
160        return self.generator(**opts).generate(expression)
161
162    def transpile(self, sql: str, **opts) -> t.List[str]:
163        return [self.generate(expression, **opts) for expression in self.parse(sql)]
164
165    def tokenize(self, sql: str) -> t.List[Token]:
166        return self.tokenizer.tokenize(sql)
167
168    @property
169    def tokenizer(self) -> Tokenizer:
170        if not hasattr(self, "_tokenizer"):
171            self._tokenizer = self.tokenizer_class()  # type: ignore
172        return self._tokenizer
173
174    def parser(self, **opts) -> Parser:
175        return self.parser_class(  # type: ignore
176            **{
177                "index_offset": self.index_offset,
178                "unnest_column_only": self.unnest_column_only,
179                "alias_post_tablesample": self.alias_post_tablesample,
180                "null_ordering": self.null_ordering,
181                **opts,
182            },
183        )
184
185    def generator(self, **opts) -> Generator:
186        return self.generator_class(  # type: ignore
187            **{
188                "quote_start": self.quote_start,
189                "quote_end": self.quote_end,
190                "bit_start": self.bit_start,
191                "bit_end": self.bit_end,
192                "hex_start": self.hex_start,
193                "hex_end": self.hex_end,
194                "byte_start": self.byte_start,
195                "byte_end": self.byte_end,
196                "identifier_start": self.identifier_start,
197                "identifier_end": self.identifier_end,
198                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
199                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
200                "index_offset": self.index_offset,
201                "time_mapping": self.inverse_time_mapping,
202                "time_trie": self.inverse_time_trie,
203                "unnest_column_only": self.unnest_column_only,
204                "alias_post_tablesample": self.alias_post_tablesample,
205                "normalize_functions": self.normalize_functions,
206                "null_ordering": self.null_ordering,
207                **opts,
208            }
209        )
@classmethod
def get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
114    @classmethod
115    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
116        if not dialect:
117            return cls
118        if isinstance(dialect, _Dialect):
119            return dialect
120        if isinstance(dialect, Dialect):
121            return dialect.__class__
122
123        result = cls.get(dialect)
124        if not result:
125            raise ValueError(f"Unknown dialect '{dialect}'")
126
127        return result
@classmethod
def format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
129    @classmethod
130    def format_time(
131        cls, expression: t.Optional[str | exp.Expression]
132    ) -> t.Optional[exp.Expression]:
133        if isinstance(expression, str):
134            return exp.Literal.string(
135                format_time(
136                    expression[1:-1],  # the time formats are quoted
137                    cls.time_mapping,
138                    cls.time_trie,
139                )
140            )
141        if expression and expression.is_string:
142            return exp.Literal.string(
143                format_time(
144                    expression.this,
145                    cls.time_mapping,
146                    cls.time_trie,
147                )
148            )
149        return expression
def parse(self, sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
151    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
152        return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
154    def parse_into(
155        self, expression_type: exp.IntoType, sql: str, **opts
156    ) -> t.List[t.Optional[exp.Expression]]:
157        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate( self, expression: Optional[sqlglot.expressions.Expression], **opts) -> str:
159    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
160        return self.generator(**opts).generate(expression)
def transpile(self, sql: str, **opts) -> List[str]:
162    def transpile(self, sql: str, **opts) -> t.List[str]:
163        return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> List[sqlglot.tokens.Token]:
165    def tokenize(self, sql: str) -> t.List[Token]:
166        return self.tokenizer.tokenize(sql)
def parser(self, **opts) -> sqlglot.parser.Parser:
174    def parser(self, **opts) -> Parser:
175        return self.parser_class(  # type: ignore
176            **{
177                "index_offset": self.index_offset,
178                "unnest_column_only": self.unnest_column_only,
179                "alias_post_tablesample": self.alias_post_tablesample,
180                "null_ordering": self.null_ordering,
181                **opts,
182            },
183        )
def generator(self, **opts) -> sqlglot.generator.Generator:
185    def generator(self, **opts) -> Generator:
186        return self.generator_class(  # type: ignore
187            **{
188                "quote_start": self.quote_start,
189                "quote_end": self.quote_end,
190                "bit_start": self.bit_start,
191                "bit_end": self.bit_end,
192                "hex_start": self.hex_start,
193                "hex_end": self.hex_end,
194                "byte_start": self.byte_start,
195                "byte_end": self.byte_end,
196                "identifier_start": self.identifier_start,
197                "identifier_end": self.identifier_end,
198                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
199                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
200                "index_offset": self.index_offset,
201                "time_mapping": self.inverse_time_mapping,
202                "time_trie": self.inverse_time_trie,
203                "unnest_column_only": self.unnest_column_only,
204                "alias_post_tablesample": self.alias_post_tablesample,
205                "normalize_functions": self.normalize_functions,
206                "null_ordering": self.null_ordering,
207                **opts,
208            }
209        )
def rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
215def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
216    return lambda self, expression: self.func(name, *flatten(expression.args.values()))
def approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
219def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
220    if expression.args.get("accuracy"):
221        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
222    return self.func("APPROX_COUNT_DISTINCT", expression.this)
def if_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.If) -> str:
225def if_sql(self: Generator, expression: exp.If) -> str:
226    return self.func(
227        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
228    )
def arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
231def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
232    return self.binary(expression, "->")
def arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
235def arrow_json_extract_scalar_sql(
236    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
237) -> str:
238    return self.binary(expression, "->>")
def inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
241def inline_array_sql(self: Generator, expression: exp.Array) -> str:
242    return f"[{self.expressions(expression)}]"
def no_ilike_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ILike) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE for dialects without it: LOWER(<this>) LIKE LOWER(<pattern>).

    Both operands are lower-cased. Lower-casing only the left side would make
    any pattern that contains uppercase characters impossible to match.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=exp.Lower(this=expression.args["expression"]),
        )
    )
def no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses.

    When the `this` argument renders to a non-empty string, it is appended as an
    AT TIME ZONE clause.
    """
    zone = self.sql(expression, "this")
    if not zone:
        return "CURRENT_DATE"
    return f"CURRENT_DATE AT TIME ZONE {zone}"
def no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Render a WITH clause for dialects that lack recursive CTEs.

    If the expression is flagged as recursive, this is reported through `self.unsupported`
    and the clause is rendered with the flag turned off.

    Args:
        expression: the With expression to render; it is never modified.

    Returns:
        The SQL produced by `self.with_sql` for the (possibly adjusted) expression.
    """
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Clear the flag on a copy: generating SQL must not alter the caller's tree,
        # otherwise a later generation of the same tree would silently lose RECURSIVE.
        expression = expression.copy()
        expression.set("recursive", False)
    return self.with_sql(expression)
def no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Render SafeDivide as an IF that yields NULL when the divisor is zero.

    Args:
        expression: the SafeDivide expression; `this` is the numerator and
            `expression` is the denominator.

    Returns:
        SQL of the form IF((d) <> 0, (n) / (d), NULL).
    """
    numerator = self.sql(expression, "this")
    denominator = self.sql(expression, "expression")
    # Both operands are spliced in as text, so wrap them in parentheses: otherwise a
    # compound operand such as `a + b` would rebind under `/` and `<>` precedence.
    return f"IF(({denominator}) <> 0, ({numerator}) / ({denominator}), NULL)"
def no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Report TABLESAMPLE as unsupported and render only the expression's `this`."""
    self.unsupported("TABLESAMPLE unsupported")
    sampled = expression.this
    return self.sql(sampled)
def no_pivot_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Pivot) -> str:
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Report PIVOT as unsupported and render it as an empty string."""
    message = "PIVOT unsupported"
    self.unsupported(message)
    return ""
def no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render a TryCast by delegating to the generator's `cast_sql`."""
    rendered = self.cast_sql(expression)
    return rendered
def no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Report properties as unsupported and render them as an empty string."""
    message = "Properties unsupported"
    self.unsupported(message)
    return ""
def no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Report the column comment constraint as unsupported and render it as an empty string."""
    message = "CommentColumnConstraint unsupported"
    self.unsupported(message)
    return ""
def str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition using STRPOS.

    Without a `position`, this is STRPOS(<this>, <substr>). With one, the search runs on
    SUBSTR(<this>, <position>) and the result is shifted by `position - 1`.

    NOTE(review): if the target's STRPOS returns 0 for "not found", the shifted form
    yields `position - 1` rather than 0 in that case -- confirm against the target dialect.
    """
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render a struct field access as <this>.<key>, with the key emitted as a quoted identifier."""
    struct_sql = self.sql(expression, "this")
    quoted_key = exp.Identifier(this=expression.expression, quoted=True)
    key_sql = self.sql(quoted_key)
    return f"{struct_sql}.{key_sql}"
def var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as a call with alternating key and value arguments.

    Args:
        expression: the map expression; its `keys` and `values` arguments are read.
        map_func_name: name of the function to emit.

    Returns:
        `map_func_name(k1, v1, k2, v2, ...)` when both `keys` and `values` are arrays.
        Otherwise the conversion is reported as unsupported and the result is
        `map_func_name(keys, values)`.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    both_arrays = isinstance(keys, exp.Array) and isinstance(values, exp.Array)
    if not both_arrays:
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave the rendered keys and values: k1, v1, k2, v2, ...
    interleaved = [
        self.sql(node)
        for pair in zip(keys.expressions, values.expressions)
        for node in pair
    ]
    return self.func(map_func_name, *interleaved)
def format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[bool, str, NoneType] = None) -> Callable[[Sequence], ~E]:
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Create a builder for time expressions whose format is normalized for a dialect.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose `format_time` (and default `time_format`) is used.
        default: fallback format used when the arguments carry none; `True` selects the
            dialect's `time_format`.

    Returns:
        A callable that takes the argument sequence and returns an `exp_class` instance
        built from the first argument and the formatted time format.
    """

    def _format_time(args: t.Sequence):
        this = seq_get(args, 0)
        dialect_class = Dialect[dialect]

        time_format = seq_get(args, 1)
        if not time_format:
            # No explicit format given: fall back to the configured default.
            time_format = dialect_class.time_format if default is True else (default or None)

        return exp_class(this=this, format=dialect_class.format_time(time_format))

    return _format_time

Helper used for time expressions.

Arguments:
  • exp_class: the expression class to instantiate.
  • dialect: target sql dialect.
  • default: the default format, True being time.
Returns:

A callable that can be used to return the appropriately formatted time expression.

def create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    Render a CREATE statement, moving partition columns into the PARTITIONED BY property.

    In Hive and Spark, PARTITIONED BY extends the table's schema. When the property holds a list
    of column names instead of a schema, the matching column definitions are taken out of the
    table schema and placed into a schema on the property. The input expression is left untouched;
    the rewrite happens on a copy.
    """
    if not isinstance(expression.this, exp.Schema) or expression.args.get("kind") not in (
        "TABLE",
        "VIEW",
    ):
        return self.create_sql(expression)

    expression = expression.copy()
    partitioned_by = expression.find(exp.PartitionedByProperty)

    if partitioned_by and partitioned_by.this and not isinstance(partitioned_by.this, exp.Schema):
        schema = expression.this
        # Column names are matched case-insensitively.
        partition_names = {column.name.upper() for column in partitioned_by.this.expressions}
        partitions = [
            column for column in schema.expressions if column.name.upper() in partition_names
        ]
        remaining = [column for column in schema.expressions if column not in partitions]

        schema.set("expressions", remaining)
        partitioned_by.replace(
            exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))
        )
        expression.set("this", schema)

    return self.create_sql(expression)

In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[Sequence], ~E]:
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Create a builder for date-delta expressions.

    Args:
        exp_class: the expression class to instantiate.
        unit_mapping: optional mapping from lowercased unit names to the unit names to emit.

    Returns:
        A callable over the argument sequence. With exactly three arguments they are read as
        (unit, expression, this); otherwise as (this, expression) with the string literal
        "DAY" as the unit.
    """

    def inner_func(args: t.Sequence) -> E:
        if len(args) == 3:
            this, unit = args[2], args[0]
        else:
            this, unit = seq_get(args, 0), exp.Literal.string("DAY")

        if unit_mapping:
            # Translate the unit name, keeping it as is when the mapping has no entry.
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[Sequence], Optional[~E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Create a builder for date-delta expressions whose second argument is an interval.

    Args:
        expression_class: the expression class to instantiate.

    Returns:
        A callable over the argument sequence. It returns None when fewer than two arguments
        are given; otherwise it builds `expression_class` from the first argument, the
        interval's value (string values are turned into number literals) and the interval's
        unit text as a string literal.
    """

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        target = args[0]
        interval = args[1]

        amount = interval.this
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=target, expression=amount, unit=unit)

    return func
def date_trunc_to_time( args: Sequence) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build a truncation expression from (unit, this) arguments.

    Returns a DateTrunc when `this` is a Cast to the DATE type, and a TimestampTrunc otherwise.
    """
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    casts_to_date = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    if not casts_to_date:
        return exp.TimestampTrunc(this=this, unit=unit)
    return exp.DateTrunc(unit=unit, this=this)
def timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC(<unit string>, <this>), with the unit defaulting to 'day'."""
    unit_name = expression.text("unit") or "day"
    unit_literal = exp.Literal.string(unit_name)
    return self.func("DATE_TRUNC", unit_literal, expression.this)
def locate_to_strposition(args: Sequence) -> sqlglot.expressions.Expression:
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Build a StrPosition from LOCATE-style arguments (substr, this[, position])."""
    this = seq_get(args, 1)
    substr = seq_get(args, 0)
    position = seq_get(args, 2)
    return exp.StrPosition(this=this, substr=substr, position=position)
def strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(<substr>, <this>, <position>)."""
    substr = expression.args.get("substr")
    this = expression.this
    position = expression.args.get("position")
    return self.func("LOCATE", substr, this, position)
def timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a CAST of `this` to TIMESTAMP."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS TIMESTAMP)"
def datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a CAST of `this` to DATE."""
    operand = self.sql(expression, "this")
    return f"CAST({operand} AS DATE)"
def min_or_least( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Min) -> str:
def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as LEAST when it carries extra `expressions`, and as MIN otherwise."""
    func_name = "MIN"
    if expression.expressions:
        func_name = "LEAST"

    render = rename_func(func_name)
    return render(self, expression)
def max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as GREATEST when it carries extra `expressions`, and as MAX otherwise."""
    func_name = "MAX"
    if expression.expressions:
        func_name = "GREATEST"

    render = rename_func(func_name)
    return render(self, expression)
def count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Render CountIf as sum(if(<condition>, 1, 0)).

    A Distinct wrapper around the condition is reported as unsupported and replaced by its
    first expression.
    """
    condition = expression.this

    if isinstance(condition, exp.Distinct):
        condition = condition.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    one_or_zero = exp.func("if", condition, 1, 0)
    return self.func("sum", one_or_zero)
def trim_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Trim) -> str:
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render Trim using the TRIM([position] [chars] FROM target [COLLATE collation]) form.

    When neither characters to remove nor a collation are present, rendering is delegated
    to the generator's own `trim_sql`.
    """
    target = self.sql(expression, "this")
    position = self.sql(expression, "position")
    chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (chars or collation):
        return self.trim_sql(expression)

    # Everything that precedes the target; FROM is only needed when something does.
    prefix = ""
    if position:
        prefix += f"{position} "
    if chars:
        prefix += f"{chars} "
    if prefix:
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
def str_to_time_sql(self, expression: sqlglot.expressions.Expression) -> str:
def str_to_time_sql(self, expression: exp.Expression) -> str:
    """Render the expression as STRPTIME(<this>, <format>), using the generator's `format_time`."""
    subject = expression.this
    time_format = self.format_time(expression)
    return self.func("STRPTIME", subject, time_format)
def ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Create a generator callback that renders TsOrDsToDate as a CAST to DATE.

    Args:
        dialect: name of the dialect whose `time_format` and `date_format` count as defaults.

    Returns:
        A callable taking the generator and the expression. When the expression has a format
        that is neither of the dialect's defaults, the operand is first parsed with
        `str_to_time_sql`; otherwise `this` is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        has_custom_format = bool(time_format) and time_format not in (
            _dialect.time_format,
            _dialect.date_format,
        )
        if has_custom_format:
            operand = str_to_time_sql(self, expression)
        else:
            operand = self.sql(expression, "this")
        return f"CAST({operand} AS DATE)"

    return _ts_or_ds_to_date_sql
499    return _ts_or_ds_to_date_sql