sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer
from sqlglot.trie import new_trie

# Type variable bound to Expression, used by helpers that construct an
# expression class chosen by the caller (e.g. format_time_lambda).
E = t.TypeVar("E", bound=exp.Expression)


class Dialects(str, Enum):
    """Enumeration of supported dialect names.

    The member value is the key under which the corresponding ``Dialect``
    subclass is registered in ``_Dialect.classes`` (see ``_Dialect.__new__``).
    """

    DIALECT = ""  # the base/default dialect registers under the empty string

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"


class _Dialect(type):
    """Metaclass for ``Dialect``.

    Registers every ``Dialect`` subclass in ``classes`` (keyed by its
    ``Dialects`` enum value when one exists, otherwise by its lowercased
    class name) and autofills the attributes that ``Dialect`` declares as
    "autofilled": the time tries, the tokenizer/parser/generator classes,
    and the quote/identifier/bit/hex/byte delimiter pairs derived from the
    tokenizer class.
    """

    # Registry of all known dialects: lookup key -> Dialect subclass.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect CLASS compares equal to: itself, its registered name
        # (string), and any INSTANCE of itself. This makes comparisons like
        # ``SomeDialect == "somedialect"`` work.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Hash by lowercased class name so a class hashes consistently with
        # its string key (defined together with __eq__ above).
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Strict registry lookup; raises KeyError for unknown keys.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        # Forgiving registry lookup; returns ``default`` for unknown keys.
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the matching Dialects enum value when one exists,
        # otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute the forward/inverse time-format lookup structures from
        # the subclass's (possibly overridden) time_mapping.
        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        # A dialect customizes tokenizing/parsing/generation by declaring
        # nested Tokenizer/Parser/Generator classes; fall back to the bases.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured quote/identifier pair is taken as canonical
        # for this dialect (tokenizers may accept several).
        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        # Bit/hex/byte string delimiters are optional; (None, None) when the
        # tokenizer defines none.
        klass.bit_start, klass.bit_end = seq_get(
            list(klass.tokenizer_class._BIT_STRINGS.items()), 0
        ) or (None, None)

        klass.hex_start, klass.hex_end = seq_get(
            list(klass.tokenizer_class._HEX_STRINGS.items()), 0
        ) or (None, None)

        klass.byte_start, klass.byte_end = seq_get(
            list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
        ) or (None, None)

        return klass


class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Subclasses are automatically registered by the ``_Dialect`` metaclass and
    tweak behavior through the class-level settings below and/or nested
    Tokenizer/Parser/Generator classes.
    """

    # Whether array/sequence indexing starts at 0 or 1 in this dialect.
    index_offset = 0
    # True when UNNEST aliases name the column rather than the relation.
    unnest_column_only = False
    # True when the table alias comes after TABLESAMPLE.
    alias_post_tablesample = False
    # How to normalize unquoted function names ("upper"/"lower"/None).
    normalize_functions: t.Optional[str] = "upper"
    # NULL sort behavior; "nulls_are_small" sorts NULLs first ascending.
    null_ordering = "nulls_are_small"

    # Canonical (quoted) date/time format literals for this dialect.
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    # Dialect-specific time-format token -> canonical token mapping.
    time_mapping: t.Dict[str, str] = {}

    # autofilled by the _Dialect metaclass — see _Dialect.__new__
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to _Dialect.__eq__: an instance equals its own class,
        # which in turn equals its registered name string.
        return type(self) == other

    def __hash__(self) -> int:
        # Consistent with __eq__: hash like the class (see _Dialect.__hash__).
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Coerce a dialect name/class/instance into a Dialect class.

        Falsy input returns ``cls`` itself; an unknown name raises ValueError.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a time-format string/literal via this dialect's time_mapping.

        Accepts a quoted string, a string Literal, or any other expression
        (returned untouched). Returns a string Literal for the two handled
        cases, otherwise the input as-is.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                format_time(
                    expression[1:-1],  # the time formats are quoted
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        if expression and expression.is_string:
            return exp.Literal.string(
                format_time(
                    expression.this,
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse ``sql`` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse ``sql``, coercing the result into ``expression_type``."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render an expression tree back into SQL text for this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse ``sql`` and re-generate each statement in this dialect."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize ``sql`` with this dialect's (cached) tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache the tokenizer per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()  # type: ignore
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a parser configured with this dialect's settings; ``opts`` win."""
        return self.parser_class(  # type: ignore
            **{
                "index_offset": self.index_offset,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "null_ordering": self.null_ordering,
                **opts,
            },
        )

    def generator(self, **opts) -> Generator:
        """Build a generator configured with this dialect's settings; ``opts`` win.

        Note the INVERSE time mapping/trie are passed: generation converts
        canonical format tokens back into dialect-specific ones.
        """
        return self.generator_class(  # type: ignore
            **{
                "quote_start": self.quote_start,
                "quote_end": self.quote_end,
                "bit_start": self.bit_start,
                "bit_end": self.bit_end,
                "hex_start": self.hex_start,
                "hex_end": self.hex_end,
                "byte_start": self.byte_start,
                "byte_end": self.byte_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                "index_offset": self.index_offset,
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "normalize_functions": self.normalize_functions,
                "null_ordering": self.null_ordering,
                **opts,
            }
        )


# Anything accepted where a dialect is expected: a registered name, an
# instance, a class, or None (meaning the default dialect).
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a transform that renders an expression as a call to ``name``,
    forwarding all of the expression's (flattened) args as arguments."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render ApproxDistinct as APPROX_COUNT_DISTINCT; the accuracy arg is
    unsupported and dropped with a warning."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    """Render If as a three-argument IF(condition, true, false) call."""
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON(B) extraction with the ``->`` arrow operator."""
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON(B) extraction with the ``->>`` arrow operator."""
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal with bracket syntax: [e1, e2, ...]."""
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE for dialects without it: LOWER(this) LIKE pattern.

    NOTE(review): only the left side is lowercased — assumes the pattern is
    already lowercase or handled by the caller; confirm.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=expression.args["expression"],
        )
    )
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional time zone."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Drop the RECURSIVE keyword (with a warning) for dialects lacking it."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE: NULL on division by zero via IF."""
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Strip TABLESAMPLE (with a warning), rendering just the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop PIVOT entirely (with a warning) for dialects lacking it."""
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Downgrade TRY_CAST to a plain CAST."""
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop table properties (with a warning) for dialects lacking them."""
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop column COMMENT constraints (with a warning)."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS; a start position is emulated by
    searching a SUBSTR and re-offsetting the result."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as dotted, quoted-identifier syntax."""
    this = self.sql(expression, "this")
    # Quote the key so reserved words / odd characters stay valid.
    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as MAP(k1, v1, k2, v2, ...) from parallel key/value arrays.

    If keys/values are not array literals (e.g. columns), the interleaving
    cannot be done statically, so fall back to MAP(keys, values) with a warning.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))
    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.Sequence):
        # args[0] is the value; args[1] the (optional) format, falling back
        # to the dialect's time_format (default is True) or the given default.
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].time_format if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Copy first: the partition columns are moved, not shared, between
        # the schema and the PARTITIONED BY property.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser for date add/diff functions.

    Handles both the 3-arg unit-first form (unit, amount, value) and the
    2-arg form (value, amount) which defaults the unit to DAY; units may be
    renamed via ``unit_mapping``.
    """

    def inner_func(args: t.Sequence) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Build a parser for date arithmetic whose second argument is an
    INTERVAL literal; returns None when too few args are supplied."""

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        expression = interval.this
        # Normalize a quoted interval amount ('5') into a number literal.
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0],
            expression=expression,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return func


def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value) into DateTrunc when the value is an
    explicit DATE cast, otherwise into TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting to 'day'."""
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Parse LOCATE(substr, string[, position]) into StrPosition
    (note the swapped first two arguments)."""
    return exp.StrPosition(
        this=seq_get(args, 1),
        substr=seq_get(args, 0),
        position=seq_get(args, 2),
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string[, position]) —
    the inverse of ``locate_to_strposition``."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a CAST to TIMESTAMP."""
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a CAST to DATE."""
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render Min as LEAST when it has multiple operands, else MIN."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render Max as GREATEST when it has multiple operands, else MAX."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emulate COUNT_IF as SUM(IF(cond, 1, 0)); DISTINCT is dropped with a
    warning since it has no SUM equivalent."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with the SQL-standard FROM syntax when a character set
    and/or collation is present.

    NOTE(review): the ``self.trim_sql`` call below assumes this helper is
    installed via a dialect's TRANSFORMS, so it dispatches to the Generator's
    default trim rendering rather than recursing — confirm.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a renderer for TsOrDsToDate for ``dialect``: parse via STRPTIME
    when a non-standard format is present, otherwise just CAST to DATE."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql


# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the output column names a PIVOT produces for ``aggregations``:
    the alias when present, otherwise the unquoted SQL of the aggregation."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
class
Dialects(builtins.str, enum.Enum):
18class Dialects(str, Enum): 19 DIALECT = "" 20 21 BIGQUERY = "bigquery" 22 CLICKHOUSE = "clickhouse" 23 DUCKDB = "duckdb" 24 HIVE = "hive" 25 MYSQL = "mysql" 26 ORACLE = "oracle" 27 POSTGRES = "postgres" 28 PRESTO = "presto" 29 REDSHIFT = "redshift" 30 SNOWFLAKE = "snowflake" 31 SPARK = "spark" 32 SPARK2 = "spark2" 33 SQLITE = "sqlite" 34 STARROCKS = "starrocks" 35 TABLEAU = "tableau" 36 TRINO = "trino" 37 TSQL = "tsql" 38 DATABRICKS = "databricks" 39 DRILL = "drill" 40 TERADATA = "teradata"
An enumeration.
DIALECT =
<Dialects.DIALECT: ''>
BIGQUERY =
<Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE =
<Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB =
<Dialects.DUCKDB: 'duckdb'>
HIVE =
<Dialects.HIVE: 'hive'>
MYSQL =
<Dialects.MYSQL: 'mysql'>
ORACLE =
<Dialects.ORACLE: 'oracle'>
POSTGRES =
<Dialects.POSTGRES: 'postgres'>
PRESTO =
<Dialects.PRESTO: 'presto'>
REDSHIFT =
<Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE =
<Dialects.SNOWFLAKE: 'snowflake'>
SPARK =
<Dialects.SPARK: 'spark'>
SPARK2 =
<Dialects.SPARK2: 'spark2'>
SQLITE =
<Dialects.SQLITE: 'sqlite'>
STARROCKS =
<Dialects.STARROCKS: 'starrocks'>
TABLEAU =
<Dialects.TABLEAU: 'tableau'>
TRINO =
<Dialects.TRINO: 'trino'>
TSQL =
<Dialects.TSQL: 'tsql'>
DATABRICKS =
<Dialects.DATABRICKS: 'databricks'>
DRILL =
<Dialects.DRILL: 'drill'>
TERADATA =
<Dialects.TERADATA: 'teradata'>
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class
Dialect:
102class Dialect(metaclass=_Dialect): 103 index_offset = 0 104 unnest_column_only = False 105 alias_post_tablesample = False 106 normalize_functions: t.Optional[str] = "upper" 107 null_ordering = "nulls_are_small" 108 109 date_format = "'%Y-%m-%d'" 110 dateint_format = "'%Y%m%d'" 111 time_format = "'%Y-%m-%d %H:%M:%S'" 112 time_mapping: t.Dict[str, str] = {} 113 114 # autofilled 115 quote_start = None 116 quote_end = None 117 identifier_start = None 118 identifier_end = None 119 120 time_trie = None 121 inverse_time_mapping = None 122 inverse_time_trie = None 123 tokenizer_class = None 124 parser_class = None 125 generator_class = None 126 127 def __eq__(self, other: t.Any) -> bool: 128 return type(self) == other 129 130 def __hash__(self) -> int: 131 return hash(type(self)) 132 133 @classmethod 134 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 135 if not dialect: 136 return cls 137 if isinstance(dialect, _Dialect): 138 return dialect 139 if isinstance(dialect, Dialect): 140 return dialect.__class__ 141 142 result = cls.get(dialect) 143 if not result: 144 raise ValueError(f"Unknown dialect '{dialect}'") 145 146 return result 147 148 @classmethod 149 def format_time( 150 cls, expression: t.Optional[str | exp.Expression] 151 ) -> t.Optional[exp.Expression]: 152 if isinstance(expression, str): 153 return exp.Literal.string( 154 format_time( 155 expression[1:-1], # the time formats are quoted 156 cls.time_mapping, 157 cls.time_trie, 158 ) 159 ) 160 if expression and expression.is_string: 161 return exp.Literal.string( 162 format_time( 163 expression.this, 164 cls.time_mapping, 165 cls.time_trie, 166 ) 167 ) 168 return expression 169 170 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 171 return self.parser(**opts).parse(self.tokenize(sql), sql) 172 173 def parse_into( 174 self, expression_type: exp.IntoType, sql: str, **opts 175 ) -> t.List[t.Optional[exp.Expression]]: 176 return self.parser(**opts).parse_into(expression_type, 
self.tokenize(sql), sql) 177 178 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 179 return self.generator(**opts).generate(expression) 180 181 def transpile(self, sql: str, **opts) -> t.List[str]: 182 return [self.generate(expression, **opts) for expression in self.parse(sql)] 183 184 def tokenize(self, sql: str) -> t.List[Token]: 185 return self.tokenizer.tokenize(sql) 186 187 @property 188 def tokenizer(self) -> Tokenizer: 189 if not hasattr(self, "_tokenizer"): 190 self._tokenizer = self.tokenizer_class() # type: ignore 191 return self._tokenizer 192 193 def parser(self, **opts) -> Parser: 194 return self.parser_class( # type: ignore 195 **{ 196 "index_offset": self.index_offset, 197 "unnest_column_only": self.unnest_column_only, 198 "alias_post_tablesample": self.alias_post_tablesample, 199 "null_ordering": self.null_ordering, 200 **opts, 201 }, 202 ) 203 204 def generator(self, **opts) -> Generator: 205 return self.generator_class( # type: ignore 206 **{ 207 "quote_start": self.quote_start, 208 "quote_end": self.quote_end, 209 "bit_start": self.bit_start, 210 "bit_end": self.bit_end, 211 "hex_start": self.hex_start, 212 "hex_end": self.hex_end, 213 "byte_start": self.byte_start, 214 "byte_end": self.byte_end, 215 "identifier_start": self.identifier_start, 216 "identifier_end": self.identifier_end, 217 "string_escape": self.tokenizer_class.STRING_ESCAPES[0], 218 "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], 219 "index_offset": self.index_offset, 220 "time_mapping": self.inverse_time_mapping, 221 "time_trie": self.inverse_time_trie, 222 "unnest_column_only": self.unnest_column_only, 223 "alias_post_tablesample": self.alias_post_tablesample, 224 "normalize_functions": self.normalize_functions, 225 "null_ordering": self.null_ordering, 226 **opts, 227 } 228 )
@classmethod
def
get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
133 @classmethod 134 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 135 if not dialect: 136 return cls 137 if isinstance(dialect, _Dialect): 138 return dialect 139 if isinstance(dialect, Dialect): 140 return dialect.__class__ 141 142 result = cls.get(dialect) 143 if not result: 144 raise ValueError(f"Unknown dialect '{dialect}'") 145 146 return result
@classmethod
def
format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
148 @classmethod 149 def format_time( 150 cls, expression: t.Optional[str | exp.Expression] 151 ) -> t.Optional[exp.Expression]: 152 if isinstance(expression, str): 153 return exp.Literal.string( 154 format_time( 155 expression[1:-1], # the time formats are quoted 156 cls.time_mapping, 157 cls.time_trie, 158 ) 159 ) 160 if expression and expression.is_string: 161 return exp.Literal.string( 162 format_time( 163 expression.this, 164 cls.time_mapping, 165 cls.time_trie, 166 ) 167 ) 168 return expression
def
parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
193 def parser(self, **opts) -> Parser: 194 return self.parser_class( # type: ignore 195 **{ 196 "index_offset": self.index_offset, 197 "unnest_column_only": self.unnest_column_only, 198 "alias_post_tablesample": self.alias_post_tablesample, 199 "null_ordering": self.null_ordering, 200 **opts, 201 }, 202 )
204 def generator(self, **opts) -> Generator: 205 return self.generator_class( # type: ignore 206 **{ 207 "quote_start": self.quote_start, 208 "quote_end": self.quote_end, 209 "bit_start": self.bit_start, 210 "bit_end": self.bit_end, 211 "hex_start": self.hex_start, 212 "hex_end": self.hex_end, 213 "byte_start": self.byte_start, 214 "byte_end": self.byte_end, 215 "identifier_start": self.identifier_start, 216 "identifier_end": self.identifier_end, 217 "string_escape": self.tokenizer_class.STRING_ESCAPES[0], 218 "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], 219 "index_offset": self.index_offset, 220 "time_mapping": self.inverse_time_mapping, 221 "time_trie": self.inverse_time_trie, 222 "unnest_column_only": self.unnest_column_only, 223 "alias_post_tablesample": self.alias_post_tablesample, 224 "normalize_functions": self.normalize_functions, 225 "null_ordering": self.null_ordering, 226 **opts, 227 } 228 )
def
rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def
approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def
arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def
arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def
inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def
no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def
no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def
no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def
no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def
no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def
no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def
no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def
str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
317def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 318 this = self.sql(expression, "this") 319 substr = self.sql(expression, "substr") 320 position = self.sql(expression, "position") 321 if position: 322 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 323 return f"STRPOS({this}, {substr})"
def
struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def
var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
332def var_map_sql( 333 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 334) -> str: 335 keys = expression.args["keys"] 336 values = expression.args["values"] 337 338 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 339 self.unsupported("Cannot convert array columns into map.") 340 return self.func(map_func_name, keys, values) 341 342 args = [] 343 for key, value in zip(keys.expressions, values.expressions): 344 args.append(self.sql(key)) 345 args.append(self.sql(value)) 346 return self.func(map_func_name, *args)
def
format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[bool, str, NoneType] = None) -> Callable[[Sequence], ~E]:
349def format_time_lambda( 350 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 351) -> t.Callable[[t.Sequence], E]: 352 """Helper used for time expressions. 353 354 Args: 355 exp_class: the expression class to instantiate. 356 dialect: target sql dialect. 357 default: the default format, True being time. 358 359 Returns: 360 A callable that can be used to return the appropriately formatted time expression. 361 """ 362 363 def _format_time(args: t.Sequence): 364 return exp_class( 365 this=seq_get(args, 0), 366 format=Dialect[dialect].format_time( 367 seq_get(args, 1) 368 or (Dialect[dialect].time_format if default is True else default or None) 369 ), 370 ) 371 372 return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def
create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
375def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 376 """ 377 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 378 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 379 columns are removed from the create statement. 380 """ 381 has_schema = isinstance(expression.this, exp.Schema) 382 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 383 384 if has_schema and is_partitionable: 385 expression = expression.copy() 386 prop = expression.find(exp.PartitionedByProperty) 387 if prop and prop.this and not isinstance(prop.this, exp.Schema): 388 schema = expression.this 389 columns = {v.name.upper() for v in prop.this.expressions} 390 partitions = [col for col in schema.expressions if col.name.upper() in columns] 391 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 392 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 393 expression.set("this", schema) 394 395 return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def
parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[Sequence], ~E]:
398def parse_date_delta( 399 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 400) -> t.Callable[[t.Sequence], E]: 401 def inner_func(args: t.Sequence) -> E: 402 unit_based = len(args) == 3 403 this = args[2] if unit_based else seq_get(args, 0) 404 unit = args[0] if unit_based else exp.Literal.string("DAY") 405 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 406 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 407 408 return inner_func
def
parse_date_delta_with_interval(expression_class: Type[E]) -> Callable[[Sequence], Optional[E]]:
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """
    Build a parser for date-delta functions whose second argument is an INTERVAL.

    Args:
        expression_class: the expression class to construct.

    Returns:
        A callable that produces the expression, or None when fewer than
        two arguments are supplied.
    """

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        amount = interval.this

        # Normalize string amounts (e.g. '5') into numeric literals.
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0],
            expression=amount,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return func
def
date_trunc_to_time( args: Sequence) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """
    Map a (unit, value) DATE_TRUNC call to DateTrunc or TimestampTrunc.

    A value explicitly cast to DATE truncates to a date; anything else
    truncates to a timestamp.
    """
    unit, this = seq_get(args, 0), seq_get(args, 1)

    truncates_date = isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE)
    klass = exp.DateTrunc if truncates_date else exp.TimestampTrunc

    return klass(this=this, unit=unit)
def
timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def
strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def
timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def
datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def
max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def
count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """
    Render COUNT_IF(cond) as SUM(IF(cond, 1, 0)) for dialects lacking COUNT_IF.

    DISTINCT inside COUNT_IF has no SUM equivalent, so it is reported as
    unsupported and the inner expression is used as the condition instead.
    """
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
        cond = cond.expressions[0]

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """
    Render TRIM, using the extended FROM/COLLATE form only when required.

    Falls back to the generator's plain TRIM/LTRIM/RTRIM rendering when there
    are no custom removal characters and no collation.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific.
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    pieces = [
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
    ]
    return f"TRIM({''.join(pieces)})"
def
ts_or_ds_to_date_sql(dialect: str) -> Callable:
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """
    Return a generator transform that casts a TS_OR_DS value to DATE for `dialect`.

    Args:
        dialect: name of the dialect whose default time/date formats are consulted.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)

        # A non-default format means the value must first be parsed as a time string.
        default_formats = (target_dialect.time_format, target_dialect.date_format)
        if time_format and time_format not in default_formats:
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"

        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
def
pivot_column_names(aggregations: List[sqlglot.expressions.Expression], dialect: Optional[Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect]]]) -> List[str]:
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """
    Compute the column names produced by a pivot's aggregation expressions.

    Args:
        aggregations: the pivot's aggregation expressions.
        dialect: dialect used to render unaliased aggregations back to SQL.

    Returns:
        One name per aggregation: the alias when present, otherwise the
        unquoted SQL text of the aggregation itself.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # NOTE: this branch previously used a bare triple-quoted string as a
            # "comment"; that is a no-op expression statement, so it is a real
            # comment now.
            # This case corresponds to aggregations without aliases being used as
            # suffixes (e.g. col_avg(foo)). We need to unquote identifiers because
            # they're going to be quoted in the base parser's `_parse_pivot` method,
            # due to `to_identifier`. Otherwise, we'd end up with `col_avg(`foo`)`
            # (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names