sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)


class Dialects(str, Enum):
    """Names of the built-in SQL dialects.

    Each member's value is the key under which `_Dialect` registers the matching `Dialect`
    subclass. Because this is a `str` enum, members compare equal to their plain string values.
    """

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # NOTE: this member name is not upper-case, so `_Dialect.__new__` cannot find it via
    # `clsname.upper()`. Its fallback key `clsname.lower()` still yields "doris"
    # (presumably the dialect class is named `Doris` — confirm).
    Doris = "doris"


class _Dialect(type):
    """Metaclass for `Dialect`.

    Each class created with it is registered in `classes`. At class-creation time it also gets
    its time-format tries, its tokenizer/parser/generator classes and its quoting delimiters.
    """

    # Registry of dialect classes, keyed by `Dialects` value or by lower-cased class name
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        """Compare a dialect class with a class, a registered name, or a dialect instance."""
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Hash of the lower-cased class name
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        """Return the dialect class registered under `key` (raises KeyError if missing)."""
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        """Return the dialect class registered under `key`, or `default` if there is none."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the enum value when the upper-cased class name is a Dialects member
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        # FORMAT_MAPPING falls back to the time mapping's trie when it is empty
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        # Nested Tokenizer/Parser/Generator classes override the base implementations
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured quote / identifier delimiters are the canonical ones
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Delimiters of the first format string of the given kind, or (None, None)
            return next(
                (
                    (start, end)
                    for start, (end, kind) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if kind == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Non-callable, non-dunder class attributes are the dialect's settings
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Subclasses override the class-level settings below. They may also define nested `Tokenizer`,
    `Parser` and `Generator` classes, which the `_Dialect` metaclass picks up.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # Delegates to the metaclass comparison (names, classes and instances)
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` to a dialect class.

        Args:
            dialect: a registered dialect name, a dialect class or a dialect instance.
                A falsy value resolves to `cls` itself.

        Returns:
            The resolved dialect class.

        Raises:
            ValueError: if `dialect` is a name that is not registered.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time format of this dialect into a python-style format string literal.

        A plain string is expected to include its surrounding quotes. A string literal
        expression is converted. Anything else is returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.
                Any other value returns False.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Set the `quoted` flag of an Identifier; other expressions are returned unchanged.

        The identifier is quoted when `identify` is true, when its name is case sensitive,
        or when its name does not match `exp.SAFE_IDENTIFIER_RE`.
        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql`; `opts` are forwarded to the parser."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize `sql` and parse it into `expression_type`; `opts` go to the parser."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Generate SQL for `expression`; `opts` are forwarded to the generator."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and regenerate each statement; `opts` go to the generator only."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Tokenize `sql` with this dialect's (cached) tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Created lazily on first access and reused afterwards
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Return a new parser instance for this dialect."""
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        """Return a new generator instance for this dialect."""
        return self.generator_class(**opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a transform rendering an expression as `name(...)` over all its (flattened) args."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Generate APPROX_COUNT_DISTINCT(this), reporting an `accuracy` argument as unsupported."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Return a transform rendering `exp.If` as `name(cond, true, false)`.

    Args:
        name: the function name to emit.
        false_value: used as the third argument when the expression has no false branch.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Generate a JSON extraction with the `->` operator."""
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Generate a scalar JSON extraction with the `->>` operator."""
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Generate an array literal with bracket syntax: `[a, b, ...]`."""
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate `x ILIKE p` as `LOWER(x) LIKE LOWER(p)` for dialects without ILIKE.

    Both operands are lower-cased. Lowering only the left side would make any pattern that
    contains upper-case letters fail to match.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this.copy()),
            expression=exp.Lower(this=expression.expression.copy()),
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Generate CURRENT_DATE without parentheses, plus `AT TIME ZONE <zone>` if one is set."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Generate a WITH clause without RECURSIVE, which is reported as unsupported."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        # Work on a copy so the caller's tree keeps its RECURSIVE flag
        expression = expression.copy()
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE(n, d) as `IF((d) <> 0, (n) / (d), NULL)`.

    The operands are spliced in as SQL text, so they are parenthesized. Otherwise an operand
    such as `a + b` would render as `a + b / d`.
    """
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE (reported as unsupported) and generate only the sampled source."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop PIVOT entirely, reporting it as unsupported."""
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Generate TRY_CAST as a plain CAST."""
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop all properties, reporting them as unsupported."""
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop a COMMENT column constraint, reporting it as unsupported."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """Drop MAP_FROM_ENTRIES, reporting it as unsupported."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Generate StrPosition with STRPOS.

    A start `position` is emulated by searching in SUBSTR(this, position) and shifting the
    result back by `position - 1`.
    NOTE(review): if the target's STRPOS returns 0 for "not found", the shifted result is
    `position - 1` rather than 0 — confirm whether callers depend on this.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Generate struct field access as `this.field`, rendering the field as an identifier."""
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Generate `map_func_name(k1, v1, k2, v2, ...)` from parallel key/value arrays.

    If keys or values are not array literals, the call falls back to
    `map_func_name(keys, values)` and is reported as unsupported. Pairing uses `zip`,
    so extra elements of the longer array are ignored.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Return a transform yielding an expression's time format, or None if it is `dialect`'s default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Partition columns are matched case-insensitively by name
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Return a builder for date-delta functions.

    Three arguments are read as (unit, expression, this). Otherwise they are read as
    (this, expression) with unit DAY. When `unit_mapping` is given, the lower-cased unit name
    is mapped through it (unknown names are kept as-is) and the unit becomes an `exp.var`.
    """

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Return a builder for FUNC(this, INTERVAL value unit) calls.

    The builder returns None for fewer than two arguments. It raises ParseError when the second
    argument is not an INTERVAL. A string interval value is converted to a number literal.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Build DateTrunc if the (unit, this) value is a CAST to DATE, otherwise TimestampTrunc."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a transform rendering `{data_type}_{kind}(this, INTERVAL expression UNIT)`.

    UNIT is upper-cased and defaults to DAY.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Generate DATE_TRUNC('<unit>', this), with the unit defaulting to 'day'."""
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.List) -> exp.Expression:
    """Build StrPosition from LOCATE(substr, this[, position]) arguments."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Generate StrPosition as LOCATE(substr, this[, position])."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Emulate LEFT(this, n) as SUBSTRING(this, 1, n)."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Emulate RIGHT(this, n) as SUBSTRING(this, LENGTH(this) - (n - 1))."""
    expression = expression.copy()
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Generate TimeStrToTime as a CAST to TIMESTAMP."""
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Generate DateStrToDate as a CAST to DATE."""
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Generate `name(this[, replace])`, reporting any non-utf-8 charset as unsupported."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Generate LEAST when extra arguments are present, otherwise MIN."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Generate GREATEST when extra arguments are present, otherwise MAX."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emulate COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT is reported and dropped."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond.copy(), 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Generate TRIM([position] [chars] FROM target [COLLATE c]).

    Falls back to the generator's own trim_sql when neither removal characters nor a
    collation are given.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Generate STRPTIME(this, format)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Return a transform that casts TsOrDsToDate's value to DATE.

    A custom format (one that is neither `dialect`'s TIME_FORMAT nor its DATE_FORMAT) is first
    applied with STRPTIME.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))

        return self.sql(exp.cast(self.sql(expression, "this"), "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Rewrite CONCAT(a, b, ...) as `a || b || ...` (reduce requires at least one argument)."""
    expression = expression.copy()
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Rewrite CONCAT_WS(delim, a, b, ...) as `a || delim || b || ...`."""
    expression = expression.copy()
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Generate REGEXP_EXTRACT(this, pattern[, group]), reporting unsupported extra args."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Generate REGEXP_REPLACE(this, pattern, replacement), reporting unsupported extra args."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Return the name suffix for each PIVOT aggregation: its alias, or its SQL text."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def simplify_literal(expression: E) -> E:
    """Run the optimizer's `simplify` on `expression.expression` unless it is already a literal.

    NOTE(review): simplify's return value is discarded, so this relies on simplify updating
    the tree in place.
    """
    if not isinstance(expression.expression, exp.Literal):
        from sqlglot.optimizer.simplify import simplify

        simplify(expression.expression)

    return expression


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Return a builder that makes `expr_type(first_arg, second_arg)`."""
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    """Build TimestampTrunc from DATE_TRUNC(unit, this) arguments."""
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Emulate ANY_VALUE(x) as MAX(x)."""
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Emulate `a XOR b` as `(a AND (NOT b)) OR ((NOT a) AND b)`.

    The operands are spliced in as SQL text. AND/OR/XOR operands are therefore wrapped in
    parentheses so they keep their grouping next to the surrounding NOT/AND.
    """

    def _operand(node: t.Optional[exp.Expression]) -> str:
        if isinstance(node, (exp.Connector, exp.Xor)):
            node = exp.paren(node)
        return self.sql(node)

    a = _operand(expression.left)
    b = _operand(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    """Generate a JSON key/value pair as `key, value`."""
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"


def is_parse_json(expression: exp.Expression) -> bool:
    """Whether `expression` is a ParseJSON call or a CAST to JSON."""
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Build `(x IS NULL)` from ISNULL(x) arguments."""
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
    """Generate an INSERT, moving a WITH clause from its query onto the INSERT (on a copy)."""
    if expression.expression.args.get("with"):
        expression = expression.copy()
        expression.set("with", expression.expression.args["with"].pop())
    return self.insert_sql(expression)
class Dialects(str, Enum):
    """Names of the built-in SQL dialects.

    Each member's value is the key under which the `_Dialect` metaclass registers the matching
    `Dialect` subclass (e.g. the base `Dialect` class maps to `DIALECT`, value ""). Because this
    is a `str` enum, members compare equal to their plain string values.
    """

    # Key of the base (generic) dialect
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    # NOTE: this member name is not upper-case, so the metaclass lookup via `clsname.upper()`
    # won't find it. Its fallback key `clsname.lower()` still yields "doris"
    # (presumably the dialect class is named `Doris` — confirm).
    Doris = "doris"
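An enumeration of the names of the built-in SQL dialects; each member's value is the key under which that dialect is registered.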
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """Base class that every SQL dialect derives from.

    A dialect customises behaviour by overriding the class-level settings below. It may also
    define nested ``Tokenizer``, ``Parser`` and ``Generator`` classes. The ``_Dialect``
    metaclass fills in the attributes marked as filled in by the metaclass when the class is
    created.
    """

    # Offset of the first element when indexing arrays
    INDEX_OFFSET = 0

    # When true, UNNEST table aliases are treated purely as column aliases
    UNNEST_COLUMN_ONLY = False

    # Whether the table alias is written after the TABLESAMPLE clause
    ALIAS_POST_TABLESAMPLE = False

    # Whether unquoted identifiers resolve to upper case (True) or lower case (False);
    # None means every identifier, quoted or not, is treated as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Whether an unquoted identifier may begin with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Whether the DPIPE token ('||') means string concatenation
    DPIPE_IS_STRING_CONCAT = True

    # Whether CONCAT requires all of its arguments to be strings
    STRICT_STRING_CONCAT = False

    # Whether user-defined data types are available
    SUPPORTS_USER_DEFINED_TYPES = True

    # Whether SEMI and ANTI joins are available
    SUPPORTS_SEMI_ANTI_JOIN = True

    # How function names get normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Whether LOG takes the base as its first argument
    LOG_BASE_FIRST = True

    # Default NULL ordering when a query doesn't specify one; one of
    # "nulls_are_small", "nulls_are_large" or "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Default format strings (stored as quoted literals)
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Maps dialect time-format tokens (keys) to python time-format tokens (values)
    TIME_MAPPING: t.Dict[str, str] = {}

    # Mapping for the special syntax cast(x as date format 'yyyy'); falls back to
    # TIME_MAPPING when left empty. References:
    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Engine-generated columns of this dialect, which may e.g. be left out of SELECT *
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Filled in by the metaclass
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # Tries over the time-mapping keys (also filled in by the metaclass)
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        # The metaclass comparison understands names, classes and instances
        return self.__class__ == other

    def __hash__(self) -> int:
        return hash(self.__class__)

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Turn a dialect name, class or instance into a dialect class.

        A falsy ``dialect`` yields ``cls``. An unknown name raises ``ValueError``.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return type(dialect)

        resolved = cls.get(dialect)
        if resolved:
            return resolved
        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Translate a dialect time format into a python format string literal.

        Plain strings must still carry their surrounding quotes. String literals are
        translated. Any other input is handed back untouched.
        """
        if isinstance(expression, str):
            # strip the quotes that wrap the time format
            unquoted = expression[1:-1]
            return exp.Literal.string(format_time(unquoted, cls.TIME_MAPPING, cls.TIME_TRIE))

        if not (expression and expression.is_string):
            return expression

        return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Lower- or upper-case an unquoted identifier so that it compares case-insensitively.
        Dialects that treat every identifier as case-insensitive normalize quoted ones as well.
        Non-identifiers are returned as-is.
        """
        if not isinstance(expression, exp.Identifier):
            return expression

        uppercase = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if expression.quoted and uppercase is not None:
            return expression

        text = expression.this
        expression.set("this", text.upper() if uppercase else text.lower())
        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Report whether ``text`` has characters that this dialect treats as case sensitive."""
        uppercase = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        if uppercase is None:
            return False

        is_unsafe = str.islower if uppercase else str.isupper
        for char in text:
            if is_unsafe(char):
                return True
        return False

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Decide whether ``text`` may be identified under the given ``identify`` option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True
        return identify == "safe" and not cls.case_sensitive(text)

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an Identifier as quoted when required; other expressions pass through."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            must_quote = (
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name)
            )
            expression.set("quoted", must_quote)

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize then parse ``sql``; ``opts`` configure the parser."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse(tokens, sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize ``sql`` and parse it as ``expression_type``; ``opts`` configure the parser."""
        parser = self.parser(**opts)
        tokens = self.tokenize(sql)
        return parser.parse_into(expression_type, tokens, sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render ``expression`` as SQL; ``opts`` configure the generator."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse ``sql`` and render every statement again; ``opts`` configure generation only."""
        statements = self.parse(sql)
        return [self.generate(statement, **opts) for statement in statements]

    def tokenize(self, sql: str) -> t.List[Token]:
        """Split ``sql`` into tokens with the dialect's shared tokenizer."""
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Built on first access, then reused for the lifetime of the instance
        if hasattr(self, "_tokenizer"):
            return self._tokenizer
        self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Create a fresh parser configured with ``opts``."""
        parser_cls = self.parser_class
        return parser_cls(**opts)

    def generator(self, **opts) -> Generator:
        """Create a fresh generator configured with ``opts``."""
        generator_cls = self.generator_class
        return generator_cls(**opts)
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
    """Resolve ``dialect`` to a Dialect class.

    A falsy value resolves to ``cls`` itself, a Dialect class is returned unchanged,
    a Dialect instance resolves to its class, and anything else is looked up by name.

    Raises:
        ValueError: if the name does not match any registered dialect.
    """
    if dialect:
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        resolved = cls.get(dialect)
        if resolved:
            return resolved
        raise ValueError(f"Unknown dialect '{dialect}'")

    return cls
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Translate a time format through this dialect's TIME_MAPPING.

    A plain ``str`` is expected to still carry its surrounding quotes, which are
    stripped before translation. A string-literal expression is translated from its
    text. Any other value (including ``None``) is returned unchanged.
    """
    if isinstance(expression, str):
        raw_format = expression[1:-1]  # the time formats are quoted
    elif expression and expression.is_string:
        raw_format = expression.this
    else:
        return expression

    return exp.Literal.string(format_time(raw_format, cls.TIME_MAPPING, cls.TIME_TRIE))
@classmethod
def normalize_identifier(cls, expression: E) -> E:
    """
    Normalizes an unquoted identifier to either lower or upper case, thus essentially
    making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
    they will be normalized regardless of being quoted or not.

    The Identifier is modified in place and returned. Non-Identifier inputs are
    returned untouched.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    # A RESOLVES_IDENTIFIERS_AS_UPPERCASE of None marks a fully case-insensitive dialect.
    insensitive_dialect = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
    if expression.quoted and not insensitive_dialect:
        return expression

    name = expression.this
    folded = name.upper() if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
    expression.set("this", folded)
    return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized regardless of being quoted or not.
@classmethod
def case_sensitive(cls, text: str) -> bool:
    """Checks if text contains any case sensitive characters, based on the dialect's rules."""
    uppercase = cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
    if uppercase is None:
        # None means identifiers are treated as case-insensitive, so nothing is sensitive.
        return False

    # Characters in the case opposite to the one the dialect folds to are "unsafe".
    is_unsafe = str.islower if uppercase else str.isupper
    for char in text:
        if is_unsafe(char):
            return True
    return False
Checks if text contains any case sensitive characters, based on the dialect's rules.
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            "always" or `True`: Always returns true.
            "safe": True if the text contains no case-sensitive characters for this dialect.
            Any other value returns false.

    Returns:
        Whether or not the given text can be identified.
    """
    always = identify is True or identify == "always"
    if always:
        return True
    return identify == "safe" and not cls.case_sensitive(text)
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: "always" or `True`: always returns true. "safe": true only if the text contains no case-sensitive characters for this dialect. Any other value returns false.
Returns:
Whether or not the given text can be identified.
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
    """Set the ``quoted`` flag on an Identifier node, in place.

    The identifier is quoted when ``identify`` is truthy, when its name contains
    case-sensitive characters for this dialect, or when the name does not match
    ``exp.SAFE_IDENTIFIER_RE``. Non-Identifier inputs are returned unchanged.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    name = expression.this
    needs_quotes = (
        identify
        or cls.case_sensitive(name)
        or not exp.SAFE_IDENTIFIER_RE.match(name)
    )
    expression.set("quoted", needs_quotes)
    return expression
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a generator transform that renders ``exp.If`` as a function call.

    Args:
        name: the function name to emit.
        false_value: used as the third argument when the If node has no (or a falsy)
            ``false`` branch.

    Returns:
        A callable ``(generator, if_expression) -> str``.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        branches = expression.args
        otherwise = branches.get("false") or false_value
        return self.func(name, expression.this, branches.get("true"), otherwise)

    return _if_sql
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render ``StrPosition`` with ``STRPOS``, emulating the optional start position.

    Without a ``position`` this is a plain ``STRPOS(this, substr)``. With one, the search
    runs on ``SUBSTR(this, position)`` and a found offset is shifted by ``position - 1``
    so that it is relative to the full string again.
    """
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")

    if not position:
        return f"STRPOS({this}, {substr})"

    strpos = f"STRPOS(SUBSTR({this}, {position}), {substr})"
    # STRPOS yields 0 when there is no match. Shifting that by the offset would report a
    # bogus match, so 0 is kept as-is. The CASE ... END wrapper also stops the trailing
    # "+ position - 1" from binding to operators around this expression (e.g. "... * 2").
    return f"CASE WHEN {strpos} = 0 THEN 0 ELSE {strpos} + {position} - 1 END"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as ``MAP_FUNC(k1, v1, k2, v2, ...)``.

    Keys and values can only be interleaved when both are array literals. Otherwise
    an "unsupported" warning is recorded and they are passed through as two arguments.
    """
    keys = expression.args["keys"]
    values = expression.args["values"]

    if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
        interleaved = [
            self.sql(item)
            for pair in zip(keys.expressions, values.expressions)
            for item in pair
        ]
        return self.func(map_func_name, *interleaved)

    self.unsupported("Cannot convert array columns into map.")
    return self.func(map_func_name, keys, values)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: name of the dialect whose ``format_time`` (TIME_MAPPING) converts the format.
        default: fallback format when the second argument is missing or falsy;
            ``True`` selects the dialect's TIME_FORMAT.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        dialect_class = Dialect[dialect]
        if default is True:
            fallback = dialect_class.TIME_FORMAT
        else:
            fallback = default or None

        fmt = seq_get(args, 1) or fallback
        return exp_class(this=seq_get(args, 0), format=dialect_class.format_time(fmt))

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
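- dialect: name of the dialect whose time mapping (via its `format_time`) is used to convert the format.
- default: the fallback format used when no format argument is given; `True` selects the dialect's TIME_FORMAT.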
Returns:
A callable that can be used to return the appropriately formatted time expression.
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a transform that yields an expression's time format, or None when it is the default.

    Args:
        dialect: the dialect whose TIME_FORMAT is considered the default.
    """

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        if fmt == default_fmt:
            return None
        return fmt

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema.
    When the PARTITIONED BY value is a list of column names (not already a schema), the
    matching column definitions (compared case-insensitively) are moved out of the table
    schema and into a PARTITIONED BY schema. Only CREATE TABLE/VIEW statements that carry a
    schema are rewritten, and the rewrite is applied to a copy of ``expression``.
    """
    partitionable_kind = expression.args.get("kind") in ("TABLE", "VIEW")
    if not isinstance(expression.this, exp.Schema) or not partitionable_kind:
        return self.create_sql(expression)

    expression = expression.copy()
    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return self.create_sql(expression)

    schema = expression.this
    partition_names = {v.name.upper() for v in prop.this.expressions}
    partitions = [col for col in schema.expressions if col.name.upper() in partition_names]
    schema.set("expressions", [e for e in schema.expressions if e not in partitions])
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is a list of column names rather than a schema, the matching column definitions (compared case-insensitively) are moved out of the table's schema and into a PARTITIONED BY schema.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser callback for date-delta functions.

    With three arguments they are read as ``(unit, expression, this)``. Otherwise they
    are read as ``(this, expression)`` with a ``'DAY'`` unit. When ``unit_mapping`` is
    given, the unit name is looked up lower-cased and the result is emitted as a Var.
    """

    def inner_func(args: t.List) -> E:
        if len(args) == 3:
            unit, this = args[0], args[2]
        else:
            unit, this = exp.Literal.string("DAY"), seq_get(args, 0)

        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))

        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser callback for ``FUNC(this, INTERVAL <amount> <unit>)`` calls.

    The callback returns None when fewer than two arguments are given, and raises
    ParseError when the second argument is not an Interval.
    """

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        this, interval = args[0], args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        if amount and amount.is_string:
            # A string-literal amount is re-created as a number literal.
            amount = exp.Literal.number(amount.this)

        unit = exp.Literal.string(interval.text("unit"))
        return expression_class(this=this, expression=amount, unit=unit)

    return func
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a transform rendering ``<data_type>_<kind>(this, INTERVAL <expr> <UNIT>)``.

    The unit name is upper-cased and defaults to DAY when absent.
    """

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit_node = expression.args.get("unit")
        unit_var = exp.var(unit_node.name.upper() if unit_node else "DAY")
        interval_sql = self.sql(exp.Interval(this=expression.expression.copy(), unit=unit_var))
        return f"{data_type}_{kind}({this}, {interval_sql})"

    return func
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an encode/decode call as ``name(this[, replace])``.

    A charset other than utf-8 (compared case-insensitively) is reported as unsupported.
    The ``replace`` argument is only emitted when the ``replace`` flag is true.
    """
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    replacement = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replacement)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite ``COUNT_IF(cond)`` as ``SUM(IF(cond, 1, 0))``.

    A DISTINCT wrapper is dropped (its first expression is used), and an
    "unsupported" warning is recorded.
    """
    inner = expression.this
    if isinstance(inner, exp.Distinct):
        inner = inner.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    indicator = exp.func("if", inner.copy(), 1, 0)
    return self.func("sum", indicator)
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render ``TRIM([type] [chars] FROM target [COLLATE c])``.

    When neither removal characters nor a collation are present, the generator's
    default ``trim_sql`` is used instead.
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not (remove_chars or collation):
        return self.trim_sql(expression)

    prefix = ""
    if trim_type:
        prefix += f"{trim_type} "
    if remove_chars:
        prefix += f"{remove_chars} "
    if prefix:
        prefix += "FROM "

    suffix = f" COLLATE {collation}" if collation else ""
    return f"TRIM({prefix}{target}{suffix})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a transform that casts a TsOrDsToDate expression to DATE.

    A format that is set and differs from the dialect's TIME_FORMAT and DATE_FORMAT is
    applied through ``str_to_time_sql`` first. Otherwise the raw value is cast directly.
    """

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target_dialect = Dialect.get_or_raise(dialect)
        fmt = self.format_time(expression)
        is_default_format = fmt in (target_dialect.TIME_FORMAT, target_dialect.DATE_FORMAT)

        if fmt and not is_default_format:
            value = str_to_time_sql(self, expression)
        else:
            value = self.sql(expression, "this")

        return self.sql(exp.cast(value, "date"))

    return _ts_or_ds_to_date_sql
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render ``CONCAT_WS(delim, a, b, ...)`` as ``a || delim || b || ...``.

    The first expression is the delimiter and the remaining ones are the values.
    A call with a delimiter but no values is rendered as an empty string literal
    instead of crashing.
    """
    expression = expression.copy()
    delim, *rest_args = expression.expressions

    if not rest_args:
        # reduce() raises TypeError on an empty sequence without an initializer.
        return self.sql(exp.Literal.string(""))

    return self.sql(
        reduce(
            # Copy the delimiter for every use so each DPipe owns its own child node
            # (a single shared node would have its parent pointer overwritten).
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim.copy(), expression=y)),
            rest_args,
        )
    )
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render ``REGEXP_EXTRACT(this, pattern[, group])``.

    Any set ``position``, ``occurrence`` or ``parameters`` argument cannot be expressed
    and is reported as unsupported.
    """
    unsupported_keys = ("position", "occurrence", "parameters")
    bad_args = [key for key in unsupported_keys if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    group = expression.args.get("group")
    return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render ``REGEXP_REPLACE(this, pattern, replacement)``.

    Any set ``position``, ``occurrence`` or ``parameters`` argument cannot be expressed
    and is reported as unsupported. ``replacement`` is read with ``args[...]``, so a
    missing replacement raises KeyError.
    """
    unsupported_keys = ("position", "occurrence", "parameters")
    bad_args = [key for key in unsupported_keys if expression.args.get(key)]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", expression.this, expression.expression, replacement)
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Compute the name suffix contributed by each PIVOT aggregation.

    Aliased aggregations contribute their alias. Other aggregations contribute their
    SQL text in ``dialect``, with all identifiers unquoted and function names lower-cased.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
            continue

        # This case corresponds to aggregations without aliases being used as suffixes
        # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
        # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
        # Otherwise, we'd end up with `col_avg(`foo`)` (notice the nested quoting).
        agg_all_unquoted = agg.transform(
            lambda node: exp.Identifier(this=node.name, quoted=False)
            if isinstance(node, exp.Identifier)
            else node
        )
        names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names