sqlglot.dataframe.sql
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
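This module exposes a PySpark-style DataFrame API that builds SQL text rather than executing anything. As an illustrative sketch only (the "employee" table, its columns, and the aggregation are invented for the example), a typical flow is to register a table schema with sqlglot, build a query through the DataFrame API, and then generate SQL strings:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register the table schema so the optimizer can qualify and type columns.
sqlglot.schema.add_table(
    "employee",
    {"employee_id": "INT", "fname": "STRING", "lname": "STRING", "age": "INT"},
    dialect="spark",
)

df = (
    SparkSession()
    .table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

# sql() returns a list of statements (usually a single one for a plain SELECT).
print(df.sql(dialect="spark")[0])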
class SparkSession:
    DEFAULT_DIALECT = "spark"
    _instance = None

    def __init__(self):
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)

    def __new__(cls, *args, **kwargs) -> SparkSession:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.tuple_(
                *map(
                    lambda x: F.lit(x).expression,
                    row if not isinstance(row, dict) else row.values(),
                )
            )
            for row in data
        ]

        sel_columns = [
            (
                F.col(name).cast(data_type).alias(name).expression
                if data_type is not None
                else F.col(name).expression
            )
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)
            return spark

    @classproperty
    def builder(cls) -> Builder:
        return cls.Builder()
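A sketch of constructing a session and an inline DataFrame (the rows and column names below are invented for illustration). Per the Builder source above, only the "sqlframe.dialect" config key is honored; any other builder attribute or config call is accepted and ignored so existing PySpark builder chains keep working:

from sqlglot.dataframe.sql.session import SparkSession

# Mirrors pyspark's builder; only the "sqlframe.dialect" key changes behavior.
spark = SparkSession.builder.config("sqlframe.dialect", "spark").getOrCreate()

df = spark.createDataFrame(
    [(1, "Jack", 37), (2, "Jill", 29)],
    schema=["employee_id", "fname", "age"],
)

# Generates a SELECT over an inline VALUES clause.
print(df.sql()[0])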
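SparkSession.sql parses a SELECT, CREATE, or INSERT statement and wraps it in a DataFrame, so the query can be extended with further DataFrame calls or re-emitted in another dialect. A hedged sketch; the table and columns are made up, and the schema registration assumes the optimizer needs it to qualify the referenced columns:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"}, dialect="spark")

spark = SparkSession()
df = spark.sql("SELECT employee_id FROM employee WHERE age > 40")
print(df.sql(dialect="duckdb")[0])

# For CREATE/INSERT, the outer statement is kept as the output container while
# the inner SELECT becomes the DataFrame being built.
ins = spark.sql("INSERT INTO employee_archive SELECT employee_id, age FROM employee")
print(ins.sql(dialect="spark")[0])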
49class DataFrame: 50 def __init__( 51 self, 52 spark: SparkSession, 53 expression: exp.Select, 54 branch_id: t.Optional[str] = None, 55 sequence_id: t.Optional[str] = None, 56 last_op: Operation = Operation.INIT, 57 pending_hints: t.Optional[t.List[exp.Expression]] = None, 58 output_expression_container: t.Optional[OutputExpressionContainer] = None, 59 **kwargs, 60 ): 61 self.spark = spark 62 self.expression = expression 63 self.branch_id = branch_id or self.spark._random_branch_id 64 self.sequence_id = sequence_id or self.spark._random_sequence_id 65 self.last_op = last_op 66 self.pending_hints = pending_hints or [] 67 self.output_expression_container = output_expression_container or exp.Select() 68 69 def __getattr__(self, column_name: str) -> Column: 70 return self[column_name] 71 72 def __getitem__(self, column_name: str) -> Column: 73 column_name = f"{self.branch_id}.{column_name}" 74 return Column(column_name) 75 76 def __copy__(self): 77 return self.copy() 78 79 @property 80 def sparkSession(self): 81 return self.spark 82 83 @property 84 def write(self): 85 return DataFrameWriter(self) 86 87 @property 88 def latest_cte_name(self) -> str: 89 if not self.expression.ctes: 90 from_exp = self.expression.args["from"] 91 if from_exp.alias_or_name: 92 return from_exp.alias_or_name 93 table_alias = from_exp.find(exp.TableAlias) 94 if not table_alias: 95 raise RuntimeError( 96 f"Could not find an alias name for this expression: {self.expression}" 97 ) 98 return table_alias.alias_or_name 99 return self.expression.ctes[-1].alias 100 101 @property 102 def pending_join_hints(self): 103 return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)] 104 105 @property 106 def pending_partition_hints(self): 107 return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)] 108 109 @property 110 def columns(self) -> t.List[str]: 111 return self.expression.named_selects 112 113 @property 114 def na(self) -> DataFrameNaFunctions: 115 return DataFrameNaFunctions(self) 116 117 def _replace_cte_names_with_hashes(self, expression: exp.Select): 118 replacement_mapping = {} 119 for cte in expression.ctes: 120 old_name_id = cte.args["alias"].this 121 new_hashed_id = exp.to_identifier( 122 self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"] 123 ) 124 replacement_mapping[old_name_id] = new_hashed_id 125 expression = expression.transform(replace_id_value, replacement_mapping) 126 return expression 127 128 def _create_cte_from_expression( 129 self, 130 expression: exp.Expression, 131 branch_id: t.Optional[str] = None, 132 sequence_id: t.Optional[str] = None, 133 **kwargs, 134 ) -> t.Tuple[exp.CTE, str]: 135 name = self._create_hash_from_expression(expression) 136 expression_to_cte = expression.copy() 137 expression_to_cte.set("with", None) 138 cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0] 139 cte.set("branch_id", branch_id or self.branch_id) 140 cte.set("sequence_id", sequence_id or self.sequence_id) 141 return cte, name 142 143 @t.overload 144 def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ... 145 146 @t.overload 147 def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ... 
148 149 def _ensure_list_of_columns(self, cols): 150 return Column.ensure_cols(ensure_list(cols)) 151 152 def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None): 153 cols = self._ensure_list_of_columns(cols) 154 normalize(self.spark, expression or self.expression, cols) 155 return cols 156 157 def _ensure_and_normalize_col(self, col): 158 col = Column.ensure_col(col) 159 normalize(self.spark, self.expression, col) 160 return col 161 162 def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame: 163 df = self._resolve_pending_hints() 164 sequence_id = sequence_id or df.sequence_id 165 expression = df.expression.copy() 166 cte_expression, cte_name = df._create_cte_from_expression( 167 expression=expression, sequence_id=sequence_id 168 ) 169 new_expression = df._add_ctes_to_expression( 170 exp.Select(), expression.ctes + [cte_expression] 171 ) 172 sel_columns = df._get_outer_select_columns(cte_expression) 173 new_expression = new_expression.from_(cte_name).select( 174 *[x.alias_or_name for x in sel_columns] 175 ) 176 return df.copy(expression=new_expression, sequence_id=sequence_id) 177 178 def _resolve_pending_hints(self) -> DataFrame: 179 df = self.copy() 180 if not self.pending_hints: 181 return df 182 expression = df.expression 183 hint_expression = expression.args.get("hint") or exp.Hint(expressions=[]) 184 for hint in df.pending_partition_hints: 185 hint_expression.append("expressions", hint) 186 df.pending_hints.remove(hint) 187 188 join_aliases = { 189 join_table.alias_or_name 190 for join_table in get_tables_from_expression_with_join(expression) 191 } 192 if join_aliases: 193 for hint in df.pending_join_hints: 194 for sequence_id_expression in hint.expressions: 195 sequence_id_or_name = sequence_id_expression.alias_or_name 196 sequence_ids_to_match = [sequence_id_or_name] 197 if sequence_id_or_name in df.spark.name_to_sequence_id_mapping: 198 sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[ 199 sequence_id_or_name 200 ] 201 matching_ctes = [ 202 cte 203 for cte in reversed(expression.ctes) 204 if cte.args["sequence_id"] in sequence_ids_to_match 205 ] 206 for matching_cte in matching_ctes: 207 if matching_cte.alias_or_name in join_aliases: 208 sequence_id_expression.set("this", matching_cte.args["alias"].this) 209 df.pending_hints.remove(hint) 210 break 211 hint_expression.append("expressions", hint) 212 if hint_expression.expressions: 213 expression.set("hint", hint_expression) 214 return df 215 216 def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame: 217 hint_name = hint_name.upper() 218 hint_expression = ( 219 exp.JoinHint( 220 this=hint_name, 221 expressions=[exp.to_table(parameter.alias_or_name) for parameter in args], 222 ) 223 if hint_name in JOIN_HINTS 224 else exp.Anonymous( 225 this=hint_name, expressions=[parameter.expression for parameter in args] 226 ) 227 ) 228 new_df = self.copy() 229 new_df.pending_hints.append(hint_expression) 230 return new_df 231 232 def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool): 233 other_df = other._convert_leaf_to_cte() 234 base_expression = self.expression.copy() 235 base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes) 236 all_ctes = base_expression.ctes 237 other_df.expression.set("with", None) 238 base_expression.set("with", None) 239 operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression) 240 operation.set("with", exp.With(expressions=all_ctes)) 241 return 
self.copy(expression=operation)._convert_leaf_to_cte() 242 243 def _cache(self, storage_level: str): 244 df = self._convert_leaf_to_cte() 245 df.expression.ctes[-1].set("cache_storage_level", storage_level) 246 return df 247 248 @classmethod 249 def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select: 250 expression = expression.copy() 251 with_expression = expression.args.get("with") 252 if with_expression: 253 existing_ctes = with_expression.expressions 254 existsing_cte_names = {x.alias_or_name for x in existing_ctes} 255 for cte in ctes: 256 if cte.alias_or_name not in existsing_cte_names: 257 existing_ctes.append(cte) 258 else: 259 existing_ctes = ctes 260 expression.set("with", exp.With(expressions=existing_ctes)) 261 return expression 262 263 @classmethod 264 def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]: 265 expression = item.expression if isinstance(item, DataFrame) else item 266 return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions] 267 268 @classmethod 269 def _create_hash_from_expression(cls, expression: exp.Expression) -> str: 270 from sqlglot.dataframe.sql.session import SparkSession 271 272 value = expression.sql(dialect=SparkSession().dialect).encode("utf-8") 273 return f"t{zlib.crc32(value)}"[:6] 274 275 def _get_select_expressions( 276 self, 277 ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]: 278 select_expressions: t.List[ 279 t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select] 280 ] = [] 281 main_select_ctes: t.List[exp.CTE] = [] 282 for cte in self.expression.ctes: 283 cache_storage_level = cte.args.get("cache_storage_level") 284 if cache_storage_level: 285 select_expression = cte.this.copy() 286 select_expression.set("with", exp.With(expressions=copy(main_select_ctes))) 287 select_expression.set("cte_alias_name", cte.alias_or_name) 288 select_expression.set("cache_storage_level", cache_storage_level) 289 select_expressions.append((exp.Cache, select_expression)) 290 else: 291 main_select_ctes.append(cte) 292 main_select = self.expression.copy() 293 if main_select_ctes: 294 main_select.set("with", exp.With(expressions=main_select_ctes)) 295 expression_select_pair = (type(self.output_expression_container), main_select) 296 select_expressions.append(expression_select_pair) # type: ignore 297 return select_expressions 298 299 def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]: 300 from sqlglot.dataframe.sql.session import SparkSession 301 302 dialect = Dialect.get_or_raise(dialect or SparkSession().dialect) 303 304 df = self._resolve_pending_hints() 305 select_expressions = df._get_select_expressions() 306 output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] 307 replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} 308 309 for expression_type, select_expression in select_expressions: 310 select_expression = select_expression.transform(replace_id_value, replacement_mapping) 311 if optimize: 312 quote_identifiers(select_expression, dialect=dialect) 313 select_expression = t.cast( 314 exp.Select, optimize_func(select_expression, dialect=dialect) 315 ) 316 317 select_expression = df._replace_cte_names_with_hashes(select_expression) 318 319 expression: t.Union[exp.Select, exp.Cache, exp.Drop] 320 if expression_type == exp.Cache: 321 cache_table_name = df._create_hash_from_expression(select_expression) 322 cache_table = 
exp.to_table(cache_table_name) 323 original_alias_name = select_expression.args["cte_alias_name"] 324 325 replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore 326 cache_table_name 327 ) 328 sqlglot.schema.add_table( 329 cache_table_name, 330 { 331 expression.alias_or_name: expression.type.sql(dialect=dialect) 332 for expression in select_expression.expressions 333 }, 334 dialect=dialect, 335 ) 336 337 cache_storage_level = select_expression.args["cache_storage_level"] 338 options = [ 339 exp.Literal.string("storageLevel"), 340 exp.Literal.string(cache_storage_level), 341 ] 342 expression = exp.Cache( 343 this=cache_table, expression=select_expression, lazy=True, options=options 344 ) 345 346 # We will drop the "view" if it exists before running the cache table 347 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) 348 elif expression_type == exp.Create: 349 expression = df.output_expression_container.copy() 350 expression.set("expression", select_expression) 351 elif expression_type == exp.Insert: 352 expression = df.output_expression_container.copy() 353 select_without_ctes = select_expression.copy() 354 select_without_ctes.set("with", None) 355 expression.set("expression", select_without_ctes) 356 357 if select_expression.ctes: 358 expression.set("with", exp.With(expressions=select_expression.ctes)) 359 elif expression_type == exp.Select: 360 expression = select_expression 361 else: 362 raise ValueError(f"Invalid expression type: {expression_type}") 363 364 output_expressions.append(expression) 365 366 return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions] 367 368 def copy(self, **kwargs) -> DataFrame: 369 return DataFrame(**object_to_dict(self, **kwargs)) 370 371 @operation(Operation.SELECT) 372 def select(self, *cols, **kwargs) -> DataFrame: 373 cols = self._ensure_and_normalize_cols(cols) 374 kwargs["append"] = kwargs.get("append", False) 375 if self.expression.args.get("joins"): 376 ambiguous_cols = [ 377 col 378 for col in cols 379 if isinstance(col.column_expression, exp.Column) and not col.column_expression.table 380 ] 381 if ambiguous_cols: 382 join_table_identifiers = [ 383 x.this for x in get_tables_from_expression_with_join(self.expression) 384 ] 385 cte_names_in_join = [x.this for x in join_table_identifiers] 386 # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right 387 # and therefore we allow multiple columns with the same name in the result. This matches the behavior 388 # of Spark. 389 resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols} 390 for ambiguous_col in ambiguous_cols: 391 ctes_with_column = [ 392 cte 393 for cte in self.expression.ctes 394 if cte.alias_or_name in cte_names_in_join 395 and ambiguous_col.alias_or_name in cte.this.named_selects 396 ] 397 # Check if there is a CTE with this column that we haven't used before. If so, use it. 
Otherwise, 398 # use the same CTE we used before 399 cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1) 400 if cte: 401 resolved_column_position[ambiguous_col] += 1 402 else: 403 cte = ctes_with_column[resolved_column_position[ambiguous_col]] 404 ambiguous_col.expression.set("table", cte.alias_or_name) 405 return self.copy( 406 expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs 407 ) 408 409 @operation(Operation.NO_OP) 410 def alias(self, name: str, **kwargs) -> DataFrame: 411 new_sequence_id = self.spark._random_sequence_id 412 df = self.copy() 413 for join_hint in df.pending_join_hints: 414 for expression in join_hint.expressions: 415 if expression.alias_or_name == self.sequence_id: 416 expression.set("this", Column.ensure_col(new_sequence_id).expression) 417 df.spark._add_alias_to_mapping(name, new_sequence_id) 418 return df._convert_leaf_to_cte(sequence_id=new_sequence_id) 419 420 @operation(Operation.WHERE) 421 def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame: 422 col = self._ensure_and_normalize_col(column) 423 return self.copy(expression=self.expression.where(col.expression)) 424 425 filter = where 426 427 @operation(Operation.GROUP_BY) 428 def groupBy(self, *cols, **kwargs) -> GroupedData: 429 columns = self._ensure_and_normalize_cols(cols) 430 return GroupedData(self, columns, self.last_op) 431 432 @operation(Operation.SELECT) 433 def agg(self, *exprs, **kwargs) -> DataFrame: 434 cols = self._ensure_and_normalize_cols(exprs) 435 return self.groupBy().agg(*cols) 436 437 @operation(Operation.FROM) 438 def join( 439 self, 440 other_df: DataFrame, 441 on: t.Union[str, t.List[str], Column, t.List[Column]], 442 how: str = "inner", 443 **kwargs, 444 ) -> DataFrame: 445 other_df = other_df._convert_leaf_to_cte() 446 join_columns = self._ensure_list_of_columns(on) 447 # We will determine actual "join on" expression later so we don't provide it at first 448 join_expression = self.expression.join( 449 other_df.latest_cte_name, join_type=how.replace("_", " ") 450 ) 451 join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes) 452 self_columns = self._get_outer_select_columns(join_expression) 453 other_columns = self._get_outer_select_columns(other_df) 454 # Determines the join clause and select columns to be used passed on what type of columns were provided for 455 # the join. The columns returned changes based on how the on expression is provided. 456 if isinstance(join_columns[0].expression, exp.Column): 457 """ 458 Unique characteristics of join on column names only: 459 * The column names are put at the front of the select list 460 * The column names are deduplicated across the entire select list and only the column names (other dups are allowed) 461 """ 462 table_names = [ 463 table.alias_or_name 464 for table in get_tables_from_expression_with_join(join_expression) 465 ] 466 potential_ctes = [ 467 cte 468 for cte in join_expression.ctes 469 if cte.alias_or_name in table_names 470 and cte.alias_or_name != other_df.latest_cte_name 471 ] 472 # Determine the table to reference for the left side of the join by checking each of the left side 473 # tables and see if they have the column being referenced. 
474 join_column_pairs = [] 475 for join_column in join_columns: 476 num_matching_ctes = 0 477 for cte in potential_ctes: 478 if join_column.alias_or_name in cte.this.named_selects: 479 left_column = join_column.copy().set_table_name(cte.alias_or_name) 480 right_column = join_column.copy().set_table_name(other_df.latest_cte_name) 481 join_column_pairs.append((left_column, right_column)) 482 num_matching_ctes += 1 483 if num_matching_ctes > 1: 484 raise ValueError( 485 f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name." 486 ) 487 elif num_matching_ctes == 0: 488 raise ValueError( 489 f"Column {join_column.alias_or_name} does not exist in any of the tables." 490 ) 491 join_clause = functools.reduce( 492 lambda x, y: x & y, 493 [left_column == right_column for left_column, right_column in join_column_pairs], 494 ) 495 join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] 496 # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list 497 select_column_names = [ 498 ( 499 column.alias_or_name 500 if not isinstance(column.expression.this, exp.Star) 501 else column.sql() 502 ) 503 for column in self_columns + other_columns 504 ] 505 select_column_names = [ 506 column_name 507 for column_name in select_column_names 508 if column_name not in join_column_names 509 ] 510 select_column_names = join_column_names + select_column_names 511 else: 512 """ 513 Unique characteristics of join on expressions: 514 * There is no deduplication of the results. 515 * The left join dataframe columns go first and right come after. No sort preference is given to join columns 516 """ 517 join_columns = self._ensure_and_normalize_cols(join_columns, join_expression) 518 if len(join_columns) > 1: 519 join_columns = [functools.reduce(lambda x, y: x & y, join_columns)] 520 join_clause = join_columns[0] 521 select_column_names = [column.alias_or_name for column in self_columns + other_columns] 522 523 # Update the on expression with the actual join clause to replace the dummy one from before 524 join_expression.args["joins"][-1].set("on", join_clause.expression) 525 new_df = self.copy(expression=join_expression) 526 new_df.pending_join_hints.extend(self.pending_join_hints) 527 new_df.pending_hints.extend(other_df.pending_hints) 528 new_df = new_df.select.__wrapped__(new_df, *select_column_names) 529 return new_df 530 531 @operation(Operation.ORDER_BY) 532 def orderBy( 533 self, 534 *cols: t.Union[str, Column], 535 ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, 536 ) -> DataFrame: 537 """ 538 This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark 539 has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this 540 is unlikely to come up. 
541 """ 542 columns = self._ensure_and_normalize_cols(cols) 543 pre_ordered_col_indexes = [ 544 i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered) 545 ] 546 if ascending is None: 547 ascending = [True] * len(columns) 548 elif not isinstance(ascending, list): 549 ascending = [ascending] * len(columns) 550 ascending = [bool(x) for i, x in enumerate(ascending)] 551 assert len(columns) == len( 552 ascending 553 ), "The length of items in ascending must equal the number of columns provided" 554 col_and_ascending = list(zip(columns, ascending)) 555 order_by_columns = [ 556 ( 557 exp.Ordered(this=col.expression, desc=not asc) 558 if i not in pre_ordered_col_indexes 559 else columns[i].column_expression 560 ) 561 for i, (col, asc) in enumerate(col_and_ascending) 562 ] 563 return self.copy(expression=self.expression.order_by(*order_by_columns)) 564 565 sort = orderBy 566 567 @operation(Operation.FROM) 568 def union(self, other: DataFrame) -> DataFrame: 569 return self._set_operation(exp.Union, other, False) 570 571 unionAll = union 572 573 @operation(Operation.FROM) 574 def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): 575 l_columns = self.columns 576 r_columns = other.columns 577 if not allowMissingColumns: 578 l_expressions = l_columns 579 r_expressions = l_columns 580 else: 581 l_expressions = [] 582 r_expressions = [] 583 r_columns_unused = copy(r_columns) 584 for l_column in l_columns: 585 l_expressions.append(l_column) 586 if l_column in r_columns: 587 r_expressions.append(l_column) 588 r_columns_unused.remove(l_column) 589 else: 590 r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False)) 591 for r_column in r_columns_unused: 592 l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False)) 593 r_expressions.append(r_column) 594 r_df = ( 595 other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) 596 ) 597 l_df = self.copy() 598 if allowMissingColumns: 599 l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) 600 return l_df._set_operation(exp.Union, r_df, False) 601 602 @operation(Operation.FROM) 603 def intersect(self, other: DataFrame) -> DataFrame: 604 return self._set_operation(exp.Intersect, other, True) 605 606 @operation(Operation.FROM) 607 def intersectAll(self, other: DataFrame) -> DataFrame: 608 return self._set_operation(exp.Intersect, other, False) 609 610 @operation(Operation.FROM) 611 def exceptAll(self, other: DataFrame) -> DataFrame: 612 return self._set_operation(exp.Except, other, False) 613 614 @operation(Operation.SELECT) 615 def distinct(self) -> DataFrame: 616 return self.copy(expression=self.expression.distinct()) 617 618 @operation(Operation.SELECT) 619 def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): 620 if not subset: 621 return self.distinct() 622 column_names = ensure_list(subset) 623 window = Window.partitionBy(*column_names).orderBy(*column_names) 624 return ( 625 self.copy() 626 .withColumn("row_num", F.row_number().over(window)) 627 .where(F.col("row_num") == F.lit(1)) 628 .drop("row_num") 629 ) 630 631 @operation(Operation.FROM) 632 def dropna( 633 self, 634 how: str = "any", 635 thresh: t.Optional[int] = None, 636 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 637 ) -> DataFrame: 638 minimum_non_null = thresh or 0 # will be determined later if thresh is null 639 new_df = self.copy() 640 all_columns = self._get_outer_select_columns(new_df.expression) 641 if subset: 642 
null_check_columns = self._ensure_and_normalize_cols(subset) 643 else: 644 null_check_columns = all_columns 645 if thresh is None: 646 minimum_num_nulls = 1 if how == "any" else len(null_check_columns) 647 else: 648 minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 649 if minimum_num_nulls > len(null_check_columns): 650 raise RuntimeError( 651 f"The minimum num nulls for dropna must be less than or equal to the number of columns. " 652 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" 653 ) 654 if_null_checks = [ 655 F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns 656 ] 657 nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) 658 num_nulls = nulls_added_together.alias("num_nulls") 659 new_df = new_df.select(num_nulls, append=True) 660 filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) 661 final_df = filtered_df.select(*all_columns) 662 return final_df 663 664 @operation(Operation.FROM) 665 def fillna( 666 self, 667 value: t.Union[ColumnLiterals], 668 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 669 ) -> DataFrame: 670 """ 671 Functionality Difference: If you provide a value to replace a null and that type conflicts 672 with the type of the column then PySpark will just ignore your replacement. 673 This will try to cast them to be the same in some cases. So they won't always match. 674 Best to not mix types so make sure replacement is the same type as the column 675 676 Possibility for improvement: Use `typeof` function to get the type of the column 677 and check if it matches the type of the value provided. If not then make it null. 678 """ 679 from sqlglot.dataframe.sql.functions import lit 680 681 values = None 682 columns = None 683 new_df = self.copy() 684 all_columns = self._get_outer_select_columns(new_df.expression) 685 all_column_mapping = {column.alias_or_name: column for column in all_columns} 686 if isinstance(value, dict): 687 values = list(value.values()) 688 columns = self._ensure_and_normalize_cols(list(value)) 689 if not columns: 690 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 691 if not values: 692 values = [value] * len(columns) 693 value_columns = [lit(value) for value in values] 694 695 null_replacement_mapping = { 696 column.alias_or_name: ( 697 F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) 698 ) 699 for column, value in zip(columns, value_columns) 700 } 701 null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} 702 null_replacement_columns = [ 703 null_replacement_mapping[column.alias_or_name] for column in all_columns 704 ] 705 new_df = new_df.select(*null_replacement_columns) 706 return new_df 707 708 @operation(Operation.FROM) 709 def replace( 710 self, 711 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 712 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 713 subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, 714 ) -> DataFrame: 715 from sqlglot.dataframe.sql.functions import lit 716 717 old_values = None 718 new_df = self.copy() 719 all_columns = self._get_outer_select_columns(new_df.expression) 720 all_column_mapping = {column.alias_or_name: column for column in all_columns} 721 722 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 723 if isinstance(to_replace, dict): 724 old_values = list(to_replace) 725 new_values = list(to_replace.values()) 
726 elif not old_values and isinstance(to_replace, list): 727 assert isinstance(value, list), "value must be a list since the replacements are a list" 728 assert len(to_replace) == len( 729 value 730 ), "the replacements and values must be the same length" 731 old_values = to_replace 732 new_values = value 733 else: 734 old_values = [to_replace] * len(columns) 735 new_values = [value] * len(columns) 736 old_values = [lit(value) for value in old_values] 737 new_values = [lit(value) for value in new_values] 738 739 replacement_mapping = {} 740 for column in columns: 741 expression = Column(None) 742 for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): 743 if i == 0: 744 expression = F.when(column == old_value, new_value) 745 else: 746 expression = expression.when(column == old_value, new_value) # type: ignore 747 replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( 748 column.expression.alias_or_name 749 ) 750 751 replacement_mapping = {**all_column_mapping, **replacement_mapping} 752 replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] 753 new_df = new_df.select(*replacement_columns) 754 return new_df 755 756 @operation(Operation.SELECT) 757 def withColumn(self, colName: str, col: Column) -> DataFrame: 758 col = self._ensure_and_normalize_col(col) 759 existing_col_names = self.expression.named_selects 760 existing_col_index = ( 761 existing_col_names.index(colName) if colName in existing_col_names else None 762 ) 763 if existing_col_index: 764 expression = self.expression.copy() 765 expression.expressions[existing_col_index] = col.expression 766 return self.copy(expression=expression) 767 return self.copy().select(col.alias(colName), append=True) 768 769 @operation(Operation.SELECT) 770 def withColumnRenamed(self, existing: str, new: str): 771 expression = self.expression.copy() 772 existing_columns = [ 773 expression 774 for expression in expression.expressions 775 if expression.alias_or_name == existing 776 ] 777 if not existing_columns: 778 raise ValueError("Tried to rename a column that doesn't exist") 779 for existing_column in existing_columns: 780 if isinstance(existing_column, exp.Column): 781 existing_column.replace(exp.alias_(existing_column, new)) 782 else: 783 existing_column.set("alias", exp.to_identifier(new)) 784 return self.copy(expression=expression) 785 786 @operation(Operation.SELECT) 787 def drop(self, *cols: t.Union[str, Column]) -> DataFrame: 788 all_columns = self._get_outer_select_columns(self.expression) 789 drop_cols = self._ensure_and_normalize_cols(cols) 790 new_columns = [ 791 col 792 for col in all_columns 793 if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] 794 ] 795 return self.copy().select(*new_columns, append=False) 796 797 @operation(Operation.LIMIT) 798 def limit(self, num: int) -> DataFrame: 799 return self.copy(expression=self.expression.limit(num)) 800 801 @operation(Operation.NO_OP) 802 def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: 803 parameter_list = ensure_list(parameters) 804 parameter_columns = ( 805 self._ensure_list_of_columns(parameter_list) 806 if parameters 807 else Column.ensure_cols([self.sequence_id]) 808 ) 809 return self._hint(name, parameter_columns) 810 811 @operation(Operation.NO_OP) 812 def repartition( 813 self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName 814 ) -> DataFrame: 815 num_partition_cols = self._ensure_list_of_columns(numPartitions) 816 columns 
= self._ensure_and_normalize_cols(cols) 817 args = num_partition_cols + columns 818 return self._hint("repartition", args) 819 820 @operation(Operation.NO_OP) 821 def coalesce(self, numPartitions: int) -> DataFrame: 822 num_partitions = Column.ensure_cols([numPartitions]) 823 return self._hint("coalesce", num_partitions) 824 825 @operation(Operation.NO_OP) 826 def cache(self) -> DataFrame: 827 return self._cache(storage_level="MEMORY_AND_DISK") 828 829 @operation(Operation.NO_OP) 830 def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame: 831 """ 832 Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html 833 """ 834 return self._cache(storageLevel)
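DataFrame.sql returns a list of SQL strings rather than a single statement, because caching a DataFrame emits DROP VIEW and CACHE TABLE statements ahead of the final query. A sketch (invented data) of generating the same query for two dialects; pretty=True is passed through to sqlglot's generator:

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, 37), (2, 29)], schema=["employee_id", "age"]).where(
    F.col("age") > F.lit(30)
)

for statement in df.sql(dialect="spark"):
    print(statement)

for statement in df.sql(dialect="bigquery", pretty=True):
    print(statement)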
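Per the join source above, joining on column names de-duplicates the join keys and moves them to the front of the select list (matching Spark), while joining on a Column expression keeps both sides' columns as-is. A small sketch with invented tables:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
employees = spark.createDataFrame([(1, "Jack", 10)], schema=["employee_id", "fname", "store_id"])
stores = spark.createDataFrame([(10, "NYC")], schema=["store_id", "store_name"])

# Join on a shared column name: "store_id" appears once, at the front of the select list.
by_name = employees.join(stores, on="store_id", how="inner")

# Join on an explicit expression: no de-duplication is applied.
by_expr = employees.join(stores, on=employees["store_id"] == stores["store_id"], how="left")

print(by_name.sql(dialect="spark")[0])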
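orderBy (and its alias sort) accepts plain columns plus an ascending flag or list, or columns that already carry an ordering; per its docstring, pre-ordered columns take priority over ascending. An illustrative sketch (Column.desc() is assumed to behave as in PySpark):

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, "Jack", 37), (2, "Jill", 29)], schema=["id", "name", "age"])

by_flags = df.orderBy("age", "name", ascending=[False, True])
by_expr = df.sort(F.col("age").desc())  # already-ordered column; ascending is not needed
print(by_flags.sql()[0])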
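unionByName matches columns by name rather than by position; with allowMissingColumns=True, columns missing on either side are padded with NULL. Sketch with invented frames:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df1 = spark.createDataFrame([(1, "Jack")], schema=["id", "name"])
df2 = spark.createDataFrame([("Jill", 2)], schema=["name", "id"])

aligned = df1.unionByName(df2)  # columns matched by name, not position

df3 = spark.createDataFrame([(3,)], schema=["id"])
padded = df1.unionByName(df3, allowMissingColumns=True)  # "name" becomes NULL for df3 rows
print(padded.sql()[0])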
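dropDuplicates without a subset falls back to DISTINCT; with a subset it is implemented as a ROW_NUMBER window over the subset columns, keeping the first row per group. Sketch:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame(
    [(1, "Jack", 37), (2, "Jack", 37), (3, "Jill", 29)],
    schema=["id", "name", "age"],
)

distinct_rows = df.dropDuplicates()            # SELECT DISTINCT
first_per_name = df.dropDuplicates(["name"])   # ROW_NUMBER() OVER (PARTITION BY name ...)
print(first_per_name.sql()[0])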
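dropna builds a null-count column, filters on it, then reselects the original columns: how="any" drops rows with at least one NULL among the checked columns, how="all" drops only rows where every checked column is NULL, and thresh keeps rows with at least that many non-null values. Sketch with invented rows:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame(
    [(1, "Jack", 37), (2, None, None), (3, "Jill", None)],
    schema=["id", "name", "age"],
)

no_nulls = df.dropna(how="any")
mostly_filled = df.dropna(thresh=2, subset=["id", "name", "age"])
print(mostly_filled.sql()[0])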
Functionality difference: if the replacement value's type conflicts with the column's type, PySpark simply ignores the replacement, whereas this implementation will try to cast the two to a common type in some cases, so results won't always match. It is best not to mix types; keep the replacement the same type as the column.

Possibility for improvement: use the `typeof` function to get the type of the column and check whether it matches the type of the provided value; if not, make the result null.
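A short sketch of the dict form, which keeps each replacement the same type as its column (the safest usage given the caveat above); the data is illustrative:

```python
from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame(
    [{"name": "alice", "age": None}, {"name": None, "age": 30}]
)

# Dict keys are column names, values are the fill values; other columns pass through.
print(df.fillna({"name": "unknown", "age": 0}).sql(pretty=True)[0])
```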
708 @operation(Operation.FROM) 709 def replace( 710 self, 711 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 712 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 713 subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, 714 ) -> DataFrame: 715 from sqlglot.dataframe.sql.functions import lit 716 717 old_values = None 718 new_df = self.copy() 719 all_columns = self._get_outer_select_columns(new_df.expression) 720 all_column_mapping = {column.alias_or_name: column for column in all_columns} 721 722 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 723 if isinstance(to_replace, dict): 724 old_values = list(to_replace) 725 new_values = list(to_replace.values()) 726 elif not old_values and isinstance(to_replace, list): 727 assert isinstance(value, list), "value must be a list since the replacements are a list" 728 assert len(to_replace) == len( 729 value 730 ), "the replacements and values must be the same length" 731 old_values = to_replace 732 new_values = value 733 else: 734 old_values = [to_replace] * len(columns) 735 new_values = [value] * len(columns) 736 old_values = [lit(value) for value in old_values] 737 new_values = [lit(value) for value in new_values] 738 739 replacement_mapping = {} 740 for column in columns: 741 expression = Column(None) 742 for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): 743 if i == 0: 744 expression = F.when(column == old_value, new_value) 745 else: 746 expression = expression.when(column == old_value, new_value) # type: ignore 747 replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( 748 column.expression.alias_or_name 749 ) 750 751 replacement_mapping = {**all_column_mapping, **replacement_mapping} 752 replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] 753 new_df = new_df.select(*replacement_columns) 754 return new_df
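A sketch of the list form, where the old and new values must be the same length and each affected column is rewritten as a chained CASE expression (names and data are illustrative):

```python
from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame([{"fname": "jack", "lname": "shephard"}])

# Only fname is rewritten because of the subset argument.
print(df.replace(["jack", "kate"], ["joe", "lily"], subset=["fname"]).sql(pretty=True)[0])
```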
756 @operation(Operation.SELECT) 757 def withColumn(self, colName: str, col: Column) -> DataFrame: 758 col = self._ensure_and_normalize_col(col) 759 existing_col_names = self.expression.named_selects 760 existing_col_index = ( 761 existing_col_names.index(colName) if colName in existing_col_names else None 762 ) 763 if existing_col_index is not None: 764 expression = self.expression.copy() 765 expression.expressions[existing_col_index] = col.expression 766 return self.copy(expression=expression) 767 return self.copy().select(col.alias(colName), append=True)
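Roughly, a new name is appended to the select list while an existing name has its select expression swapped in place. A quick sketch with made-up columns:

```python
from sqlglot.dataframe.sql import SparkSession, functions as F

spark = SparkSession()
df = spark.createDataFrame([{"price": 10, "qty": 3}])

df = df.withColumn("total", F.col("price") * F.col("qty"))  # appended as a new aliased column
df = df.withColumn("qty", F.col("qty") + 1)                 # replaces the existing qty expression
print(df.sql(pretty=True)[0])
```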
769 @operation(Operation.SELECT) 770 def withColumnRenamed(self, existing: str, new: str): 771 expression = self.expression.copy() 772 existing_columns = [ 773 expression 774 for expression in expression.expressions 775 if expression.alias_or_name == existing 776 ] 777 if not existing_columns: 778 raise ValueError("Tried to rename a column that doesn't exist") 779 for existing_column in existing_columns: 780 if isinstance(existing_column, exp.Column): 781 existing_column.replace(exp.alias_(existing_column, new)) 782 else: 783 existing_column.set("alias", exp.to_identifier(new)) 784 return self.copy(expression=expression)
786 @operation(Operation.SELECT) 787 def drop(self, *cols: t.Union[str, Column]) -> DataFrame: 788 all_columns = self._get_outer_select_columns(self.expression) 789 drop_cols = self._ensure_and_normalize_cols(cols) 790 new_columns = [ 791 col 792 for col in all_columns 793 if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] 794 ] 795 return self.copy().select(*new_columns, append=False)
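withColumnRenamed rewrites the alias of the matching select expression, and drop rebuilds the select list without the named columns; a brief illustrative sketch:

```python
from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame([{"fname": "kate", "lname": "austen", "age": 30}])

print(df.withColumnRenamed("fname", "first_name").drop("age").sql(pretty=True)[0])
```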
801 @operation(Operation.NO_OP) 802 def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: 803 parameter_list = ensure_list(parameters) 804 parameter_columns = ( 805 self._ensure_list_of_columns(parameter_list) 806 if parameters 807 else Column.ensure_cols([self.sequence_id]) 808 ) 809 return self._hint(name, parameter_columns)
811 @operation(Operation.NO_OP) 812 def repartition( 813 self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName 814 ) -> DataFrame: 815 num_partition_cols = self._ensure_list_of_columns(numPartitions) 816 columns = self._ensure_and_normalize_cols(cols) 817 args = num_partition_cols + columns 818 return self._hint("repartition", args)
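Both hint and repartition are no-ops with respect to the result of the query; they are carried along as hint metadata attached to the generated select. A hedged sketch:

```python
from sqlglot.dataframe.sql import SparkSession, functions as F

spark = SparkSession()
df = spark.createDataFrame([{"id": 1}])

# The partition count and columns become hint parameters rather than changing the query shape.
print(df.repartition(4, F.col("id")).sql(pretty=True)[0])
```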
829 @operation(Operation.NO_OP) 830 def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame: 831 """ 832 Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html 833 """ 834 return self._cache(storageLevel)
Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
14class GroupedData: 15 def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation): 16 self._df = df.copy() 17 self.spark = df.spark 18 self.last_op = last_op 19 self.group_by_cols = group_by_cols 20 21 def _get_function_applied_columns( 22 self, func_name: str, cols: t.Tuple[str, ...] 23 ) -> t.List[Column]: 24 func_name = func_name.lower() 25 return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] 26 27 @operation(Operation.SELECT) 28 def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: 29 columns = ( 30 [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] 31 if isinstance(exprs[0], dict) 32 else exprs 33 ) 34 cols = self._df._ensure_and_normalize_cols(columns) 35 36 expression = self._df.expression.group_by( 37 *[x.expression for x in self.group_by_cols] 38 ).select(*[x.expression for x in self.group_by_cols + cols], append=False) 39 return self._df.copy(expression=expression) 40 41 def count(self) -> DataFrame: 42 return self.agg(F.count("*").alias("count")) 43 44 def mean(self, *cols: str) -> DataFrame: 45 return self.avg(*cols) 46 47 def avg(self, *cols: str) -> DataFrame: 48 return self.agg(*self._get_function_applied_columns("avg", cols)) 49 50 def max(self, *cols: str) -> DataFrame: 51 return self.agg(*self._get_function_applied_columns("max", cols)) 52 53 def min(self, *cols: str) -> DataFrame: 54 return self.agg(*self._get_function_applied_columns("min", cols)) 55 56 def sum(self, *cols: str) -> DataFrame: 57 return self.agg(*self._get_function_applied_columns("sum", cols)) 58 59 def pivot(self, *cols: str) -> DataFrame: 60 raise NotImplementedError("Sum distinct is not currently implemented")
27 @operation(Operation.SELECT) 28 def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: 29 columns = ( 30 [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] 31 if isinstance(exprs[0], dict) 32 else exprs 33 ) 34 cols = self._df._ensure_and_normalize_cols(columns) 35 36 expression = self._df.expression.group_by( 37 *[x.expression for x in self.group_by_cols] 38 ).select(*[x.expression for x in self.group_by_cols + cols], append=False) 39 return self._df.copy(expression=expression)
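agg accepts either explicit Column aggregates or the dict shorthand, which is turned into "<func>(<column>)" strings and parsed. A small sketch (data and names are illustrative):

```python
from sqlglot.dataframe.sql import SparkSession, functions as F

spark = SparkSession()
df = spark.createDataFrame(
    [{"store": 1, "price": 100}, {"store": 1, "price": 200}]
)

print(df.groupBy(F.col("store")).agg(F.min("price"), F.max("price")).sql(pretty=True)[0])
print(df.groupBy(F.col("store")).agg({"price": "avg"}).sql(pretty=True)[0])
```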
16class Column: 17 def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): 18 from sqlglot.dataframe.sql.session import SparkSession 19 20 if isinstance(expression, Column): 21 expression = expression.expression # type: ignore 22 elif expression is None or not isinstance(expression, (str, exp.Expression)): 23 expression = self._lit(expression).expression # type: ignore 24 elif not isinstance(expression, exp.Column): 25 expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( 26 SparkSession().dialect.normalize_identifier, copy=False 27 ) 28 if expression is None: 29 raise ValueError(f"Could not parse {expression}") 30 31 self.expression: exp.Expression = expression # type: ignore 32 33 def __repr__(self): 34 return repr(self.expression) 35 36 def __hash__(self): 37 return hash(self.expression) 38 39 def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore 40 return self.binary_op(exp.EQ, other) 41 42 def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore 43 return self.binary_op(exp.NEQ, other) 44 45 def __gt__(self, other: ColumnOrLiteral) -> Column: 46 return self.binary_op(exp.GT, other) 47 48 def __ge__(self, other: ColumnOrLiteral) -> Column: 49 return self.binary_op(exp.GTE, other) 50 51 def __lt__(self, other: ColumnOrLiteral) -> Column: 52 return self.binary_op(exp.LT, other) 53 54 def __le__(self, other: ColumnOrLiteral) -> Column: 55 return self.binary_op(exp.LTE, other) 56 57 def __and__(self, other: ColumnOrLiteral) -> Column: 58 return self.binary_op(exp.And, other) 59 60 def __or__(self, other: ColumnOrLiteral) -> Column: 61 return self.binary_op(exp.Or, other) 62 63 def __mod__(self, other: ColumnOrLiteral) -> Column: 64 return self.binary_op(exp.Mod, other) 65 66 def __add__(self, other: ColumnOrLiteral) -> Column: 67 return self.binary_op(exp.Add, other) 68 69 def __sub__(self, other: ColumnOrLiteral) -> Column: 70 return self.binary_op(exp.Sub, other) 71 72 def __mul__(self, other: ColumnOrLiteral) -> Column: 73 return self.binary_op(exp.Mul, other) 74 75 def __truediv__(self, other: ColumnOrLiteral) -> Column: 76 return self.binary_op(exp.Div, other) 77 78 def __div__(self, other: ColumnOrLiteral) -> Column: 79 return self.binary_op(exp.Div, other) 80 81 def __neg__(self) -> Column: 82 return self.unary_op(exp.Neg) 83 84 def __radd__(self, other: ColumnOrLiteral) -> Column: 85 return self.inverse_binary_op(exp.Add, other) 86 87 def __rsub__(self, other: ColumnOrLiteral) -> Column: 88 return self.inverse_binary_op(exp.Sub, other) 89 90 def __rmul__(self, other: ColumnOrLiteral) -> Column: 91 return self.inverse_binary_op(exp.Mul, other) 92 93 def __rdiv__(self, other: ColumnOrLiteral) -> Column: 94 return self.inverse_binary_op(exp.Div, other) 95 96 def __rtruediv__(self, other: ColumnOrLiteral) -> Column: 97 return self.inverse_binary_op(exp.Div, other) 98 99 def __rmod__(self, other: ColumnOrLiteral) -> Column: 100 return self.inverse_binary_op(exp.Mod, other) 101 102 def __pow__(self, power: ColumnOrLiteral, modulo=None): 103 return Column(exp.Pow(this=self.expression, expression=Column(power).expression)) 104 105 def __rpow__(self, power: ColumnOrLiteral): 106 return Column(exp.Pow(this=Column(power).expression, expression=self.expression)) 107 108 def __invert__(self): 109 return self.unary_op(exp.Not) 110 111 def __rand__(self, other: ColumnOrLiteral) -> Column: 112 return self.inverse_binary_op(exp.And, other) 113 114 def __ror__(self, other: ColumnOrLiteral) -> Column: 115 return 
self.inverse_binary_op(exp.Or, other) 116 117 @classmethod 118 def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column: 119 return cls(value) 120 121 @classmethod 122 def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]: 123 return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args] 124 125 @classmethod 126 def _lit(cls, value: ColumnOrLiteral) -> Column: 127 if isinstance(value, dict): 128 columns = [cls._lit(v).alias(k).expression for k, v in value.items()] 129 return cls(exp.Struct(expressions=columns)) 130 return cls(exp.convert(value)) 131 132 @classmethod 133 def invoke_anonymous_function( 134 cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] 135 ) -> Column: 136 columns = [] if column is None else [cls.ensure_col(column)] 137 column_args = [cls.ensure_col(arg) for arg in args] 138 expressions = [x.expression for x in columns + column_args] 139 new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) 140 return Column(new_expression) 141 142 @classmethod 143 def invoke_expression_over_column( 144 cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs 145 ) -> Column: 146 ensured_column = None if column is None else cls.ensure_col(column) 147 ensure_expression_values = { 148 k: ( 149 [Column.ensure_col(x).expression for x in v] 150 if is_iterable(v) 151 else Column.ensure_col(v).expression 152 ) 153 for k, v in kwargs.items() 154 if v is not None 155 } 156 new_expression = ( 157 callable_expression(**ensure_expression_values) 158 if ensured_column is None 159 else callable_expression( 160 this=ensured_column.column_expression, **ensure_expression_values 161 ) 162 ) 163 return Column(new_expression) 164 165 def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: 166 return Column( 167 klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs) 168 ) 169 170 def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: 171 return Column( 172 klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs) 173 ) 174 175 def unary_op(self, klass: t.Callable, **kwargs) -> Column: 176 return Column(klass(this=self.column_expression, **kwargs)) 177 178 @property 179 def is_alias(self): 180 return isinstance(self.expression, exp.Alias) 181 182 @property 183 def is_column(self): 184 return isinstance(self.expression, exp.Column) 185 186 @property 187 def column_expression(self) -> t.Union[exp.Column, exp.Literal]: 188 return self.expression.unalias() 189 190 @property 191 def alias_or_name(self) -> str: 192 return self.expression.alias_or_name 193 194 @classmethod 195 def ensure_literal(cls, value) -> Column: 196 from sqlglot.dataframe.sql.functions import lit 197 198 if isinstance(value, cls): 199 value = value.expression 200 if not isinstance(value, exp.Literal): 201 return lit(value) 202 return Column(value) 203 204 def copy(self) -> Column: 205 return Column(self.expression.copy()) 206 207 def set_table_name(self, table_name: str, copy=False) -> Column: 208 expression = self.expression.copy() if copy else self.expression 209 expression.set("table", exp.to_identifier(table_name)) 210 return Column(expression) 211 212 def sql(self, **kwargs) -> str: 213 from sqlglot.dataframe.sql.session import SparkSession 214 215 return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) 216 
217 def alias(self, name: str) -> Column: 218 from sqlglot.dataframe.sql.session import SparkSession 219 220 dialect = SparkSession().dialect 221 alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect) 222 new_expression = exp.alias_( 223 self.column_expression, 224 alias.this if isinstance(alias, exp.Column) else name, 225 dialect=dialect, 226 ) 227 return Column(new_expression) 228 229 def asc(self) -> Column: 230 new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True) 231 return Column(new_expression) 232 233 def desc(self) -> Column: 234 new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False) 235 return Column(new_expression) 236 237 asc_nulls_first = asc 238 239 def asc_nulls_last(self) -> Column: 240 new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False) 241 return Column(new_expression) 242 243 def desc_nulls_first(self) -> Column: 244 new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True) 245 return Column(new_expression) 246 247 desc_nulls_last = desc 248 249 def when(self, condition: Column, value: t.Any) -> Column: 250 from sqlglot.dataframe.sql.functions import when 251 252 column_with_if = when(condition, value) 253 if not isinstance(self.expression, exp.Case): 254 return column_with_if 255 new_column = self.copy() 256 new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) 257 return new_column 258 259 def otherwise(self, value: t.Any) -> Column: 260 from sqlglot.dataframe.sql.functions import lit 261 262 true_value = value if isinstance(value, Column) else lit(value) 263 new_column = self.copy() 264 new_column.expression.set("default", true_value.column_expression) 265 return new_column 266 267 def isNull(self) -> Column: 268 new_expression = exp.Is(this=self.column_expression, expression=exp.Null()) 269 return Column(new_expression) 270 271 def isNotNull(self) -> Column: 272 new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) 273 return Column(new_expression) 274 275 def cast(self, dataType: t.Union[str, DataType]) -> Column: 276 """ 277 Functionality Difference: PySpark cast accepts a datatype instance of the datatype class 278 Sqlglot doesn't currently replicate this class so it only accepts a string 279 """ 280 from sqlglot.dataframe.sql.session import SparkSession 281 282 if isinstance(dataType, DataType): 283 dataType = dataType.simpleString() 284 return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect)) 285 286 def startswith(self, value: t.Union[str, Column]) -> Column: 287 value = self._lit(value) if not isinstance(value, Column) else value 288 return self.invoke_anonymous_function(self, "STARTSWITH", value) 289 290 def endswith(self, value: t.Union[str, Column]) -> Column: 291 value = self._lit(value) if not isinstance(value, Column) else value 292 return self.invoke_anonymous_function(self, "ENDSWITH", value) 293 294 def rlike(self, regexp: str) -> Column: 295 return self.invoke_expression_over_column( 296 column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression 297 ) 298 299 def like(self, other: str): 300 return self.invoke_expression_over_column( 301 self, exp.Like, expression=self._lit(other).expression 302 ) 303 304 def ilike(self, other: str): 305 return self.invoke_expression_over_column( 306 self, exp.ILike, expression=self._lit(other).expression 307 ) 308 309 def substr(self, startPos: 
t.Union[int, Column], length: t.Union[int, Column]) -> Column: 310 startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos 311 length = self._lit(length) if not isinstance(length, Column) else length 312 return Column.invoke_expression_over_column( 313 self, exp.Substring, start=startPos.expression, length=length.expression 314 ) 315 316 def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): 317 columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 318 expressions = [self._lit(x).expression for x in columns] 319 return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore 320 321 def between( 322 self, 323 lowerBound: t.Union[ColumnOrLiteral], 324 upperBound: t.Union[ColumnOrLiteral], 325 ) -> Column: 326 lower_bound_exp = ( 327 self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound 328 ) 329 upper_bound_exp = ( 330 self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound 331 ) 332 return Column( 333 exp.Between( 334 this=self.column_expression, 335 low=lower_bound_exp.expression, 336 high=upper_bound_exp.expression, 337 ) 338 ) 339 340 def over(self, window: WindowSpec) -> Column: 341 window_expression = window.expression.copy() 342 window_expression.set("this", self.column_expression) 343 return Column(window_expression)
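The operator overloads above build sqlglot expression trees, and Column.sql() renders them with the session's dialect (Spark by default). A tiny sketch; the column name is arbitrary:

```python
from sqlglot.dataframe.sql import functions as F

expr = ((F.col("price") * F.lit(1.1)) > F.lit(100)).alias("is_expensive")
print(expr.sql())  # roughly: price * 1.1 > 100 AS is_expensive
```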
17 def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): 18 from sqlglot.dataframe.sql.session import SparkSession 19 20 if isinstance(expression, Column): 21 expression = expression.expression # type: ignore 22 elif expression is None or not isinstance(expression, (str, exp.Expression)): 23 expression = self._lit(expression).expression # type: ignore 24 elif not isinstance(expression, exp.Column): 25 expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( 26 SparkSession().dialect.normalize_identifier, copy=False 27 ) 28 if expression is None: 29 raise ValueError(f"Could not parse {expression}") 30 31 self.expression: exp.Expression = expression # type: ignore
132 @classmethod 133 def invoke_anonymous_function( 134 cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] 135 ) -> Column: 136 columns = [] if column is None else [cls.ensure_col(column)] 137 column_args = [cls.ensure_col(arg) for arg in args] 138 expressions = [x.expression for x in columns + column_args] 139 new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) 140 return Column(new_expression)
142 @classmethod 143 def invoke_expression_over_column( 144 cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs 145 ) -> Column: 146 ensured_column = None if column is None else cls.ensure_col(column) 147 ensure_expression_values = { 148 k: ( 149 [Column.ensure_col(x).expression for x in v] 150 if is_iterable(v) 151 else Column.ensure_col(v).expression 152 ) 153 for k, v in kwargs.items() 154 if v is not None 155 } 156 new_expression = ( 157 callable_expression(**ensure_expression_values) 158 if ensured_column is None 159 else callable_expression( 160 this=ensured_column.column_expression, **ensure_expression_values 161 ) 162 ) 163 return Column(new_expression)
217 def alias(self, name: str) -> Column: 218 from sqlglot.dataframe.sql.session import SparkSession 219 220 dialect = SparkSession().dialect 221 alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect) 222 new_expression = exp.alias_( 223 self.column_expression, 224 alias.this if isinstance(alias, exp.Column) else name, 225 dialect=dialect, 226 ) 227 return Column(new_expression)
249 def when(self, condition: Column, value: t.Any) -> Column: 250 from sqlglot.dataframe.sql.functions import when 251 252 column_with_if = when(condition, value) 253 if not isinstance(self.expression, exp.Case): 254 return column_with_if 255 new_column = self.copy() 256 new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) 257 return new_column
275 def cast(self, dataType: t.Union[str, DataType]) -> Column: 276 """ 277 Functionality Difference: PySpark cast accepts a datatype instance of the datatype class 278 Sqlglot doesn't currently replicate this class so it only accepts a string 279 """ 280 from sqlglot.dataframe.sql.session import SparkSession 281 282 if isinstance(dataType, DataType): 283 dataType = dataType.simpleString() 284 return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
Functionality difference: PySpark's cast accepts an instance of the DataType class. Sqlglot doesn't currently replicate this class, so the primary input is a string type name (a passed DataType is converted via simpleString(), as in the source above).
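In practice the string form is the one to reach for. For example:

```python
from sqlglot.dataframe.sql import functions as F

# Renders a CAST in the current session's dialect (Spark by default).
print(F.col("age").cast("bigint").sql())
```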
309 def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: 310 startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos 311 length = self._lit(length) if not isinstance(length, Column) else length 312 return Column.invoke_expression_over_column( 313 self, exp.Substring, start=startPos.expression, length=length.expression 314 )
316 def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): 317 columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 318 expressions = [self._lit(x).expression for x in columns] 319 return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore
321 def between( 322 self, 323 lowerBound: t.Union[ColumnOrLiteral], 324 upperBound: t.Union[ColumnOrLiteral], 325 ) -> Column: 326 lower_bound_exp = ( 327 self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound 328 ) 329 upper_bound_exp = ( 330 self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound 331 ) 332 return Column( 333 exp.Between( 334 this=self.column_expression, 335 low=lower_bound_exp.expression, 336 high=upper_bound_exp.expression, 337 ) 338 )
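The bounds (and isin values) may be literals or Columns; plain literals are wrapped via _lit. A short sketch with arbitrary column names:

```python
from sqlglot.dataframe.sql import functions as F

print(F.col("age").between(18, 65).sql())
print(F.col("state").isin("CA", "NY", "WA").sql())
```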
837class DataFrameNaFunctions: 838 def __init__(self, df: DataFrame): 839 self.df = df 840 841 def drop( 842 self, 843 how: str = "any", 844 thresh: t.Optional[int] = None, 845 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 846 ) -> DataFrame: 847 return self.df.dropna(how=how, thresh=thresh, subset=subset) 848 849 def fill( 850 self, 851 value: t.Union[int, bool, float, str, t.Dict[str, t.Any]], 852 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 853 ) -> DataFrame: 854 return self.df.fillna(value=value, subset=subset) 855 856 def replace( 857 self, 858 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 859 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 860 subset: t.Optional[t.Union[str, t.List[str]]] = None, 861 ) -> DataFrame: 862 return self.df.replace(to_replace=to_replace, value=value, subset=subset)
856 def replace( 857 self, 858 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 859 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 860 subset: t.Optional[t.Union[str, t.List[str]]] = None, 861 ) -> DataFrame: 862 return self.df.replace(to_replace=to_replace, value=value, subset=subset)
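These are thin delegates to dropna/fillna/replace, reached through the DataFrame's na accessor. A brief sketch with made-up data:

```python
from sqlglot.dataframe.sql import SparkSession

spark = SparkSession()
df = spark.createDataFrame([{"a": 1, "b": None}])

print(df.na.fill(0, subset=["b"]).sql(pretty=True)[0])
```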
15class Window: 16 _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 17 _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 18 _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG) 19 _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG) 20 21 unboundedPreceding: int = _JAVA_MIN_LONG 22 23 unboundedFollowing: int = _JAVA_MAX_LONG 24 25 currentRow: int = 0 26 27 @classmethod 28 def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 29 return WindowSpec().partitionBy(*cols) 30 31 @classmethod 32 def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 33 return WindowSpec().orderBy(*cols) 34 35 @classmethod 36 def rowsBetween(cls, start: int, end: int) -> WindowSpec: 37 return WindowSpec().rowsBetween(start, end) 38 39 @classmethod 40 def rangeBetween(cls, start: int, end: int) -> WindowSpec: 41 return WindowSpec().rangeBetween(start, end)
44class WindowSpec: 45 def __init__(self, expression: exp.Expression = exp.Window()): 46 self.expression = expression 47 48 def copy(self): 49 return WindowSpec(self.expression.copy()) 50 51 def sql(self, **kwargs) -> str: 52 from sqlglot.dataframe.sql.session import SparkSession 53 54 return self.expression.sql(dialect=SparkSession().dialect, **kwargs) 55 56 def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 57 from sqlglot.dataframe.sql.column import Column 58 59 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 60 expressions = [Column.ensure_col(x).expression for x in cols] 61 window_spec = self.copy() 62 partition_by_expressions = window_spec.expression.args.get("partition_by", []) 63 partition_by_expressions.extend(expressions) 64 window_spec.expression.set("partition_by", partition_by_expressions) 65 return window_spec 66 67 def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 68 from sqlglot.dataframe.sql.column import Column 69 70 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 71 expressions = [Column.ensure_col(x).expression for x in cols] 72 window_spec = self.copy() 73 if window_spec.expression.args.get("order") is None: 74 window_spec.expression.set("order", exp.Order(expressions=[])) 75 order_by = window_spec.expression.args["order"].expressions 76 order_by.extend(expressions) 77 window_spec.expression.args["order"].set("expressions", order_by) 78 return window_spec 79 80 def _calc_start_end( 81 self, start: int, end: int 82 ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]: 83 kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = { 84 "start_side": None, 85 "end_side": None, 86 } 87 if start == Window.currentRow: 88 kwargs["start"] = "CURRENT ROW" 89 else: 90 kwargs = { 91 **kwargs, 92 **{ 93 "start_side": "PRECEDING", 94 "start": ( 95 "UNBOUNDED" 96 if start <= Window.unboundedPreceding 97 else F.lit(start).expression 98 ), 99 }, 100 } 101 if end == Window.currentRow: 102 kwargs["end"] = "CURRENT ROW" 103 else: 104 kwargs = { 105 **kwargs, 106 **{ 107 "end_side": "FOLLOWING", 108 "end": ( 109 "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression 110 ), 111 }, 112 } 113 return kwargs 114 115 def rowsBetween(self, start: int, end: int) -> WindowSpec: 116 window_spec = self.copy() 117 spec = self._calc_start_end(start, end) 118 spec["kind"] = "ROWS" 119 window_spec.expression.set( 120 "spec", 121 exp.WindowSpec( 122 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 123 ), 124 ) 125 return window_spec 126 127 def rangeBetween(self, start: int, end: int) -> WindowSpec: 128 window_spec = self.copy() 129 spec = self._calc_start_end(start, end) 130 spec["kind"] = "RANGE" 131 window_spec.expression.set( 132 "spec", 133 exp.WindowSpec( 134 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 135 ), 136 ) 137 return window_spec
56 def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 57 from sqlglot.dataframe.sql.column import Column 58 59 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 60 expressions = [Column.ensure_col(x).expression for x in cols] 61 window_spec = self.copy() 62 partition_by_expressions = window_spec.expression.args.get("partition_by", []) 63 partition_by_expressions.extend(expressions) 64 window_spec.expression.set("partition_by", partition_by_expressions) 65 return window_spec
67 def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 68 from sqlglot.dataframe.sql.column import Column 69 70 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 71 expressions = [Column.ensure_col(x).expression for x in cols] 72 window_spec = self.copy() 73 if window_spec.expression.args.get("order") is None: 74 window_spec.expression.set("order", exp.Order(expressions=[])) 75 order_by = window_spec.expression.args["order"].expressions 76 order_by.extend(expressions) 77 window_spec.expression.args["order"].set("expressions", order_by) 78 return window_spec
115 def rowsBetween(self, start: int, end: int) -> WindowSpec: 116 window_spec = self.copy() 117 spec = self._calc_start_end(start, end) 118 spec["kind"] = "ROWS" 119 window_spec.expression.set( 120 "spec", 121 exp.WindowSpec( 122 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 123 ), 124 ) 125 return window_spec
127 def rangeBetween(self, start: int, end: int) -> WindowSpec: 128 window_spec = self.copy() 129 spec = self._calc_start_end(start, end) 130 spec["kind"] = "RANGE" 131 window_spec.expression.set( 132 "spec", 133 exp.WindowSpec( 134 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 135 ), 136 ) 137 return window_spec
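Putting the pieces together: partitionBy/orderBy accumulate into the window expression, and rowsBetween/rangeBetween fill in its spec. A sketch of a running total with illustrative data:

```python
from sqlglot.dataframe.sql import SparkSession, Window, functions as F

spark = SparkSession()
df = spark.createDataFrame(
    [{"store": 1, "day": 1, "sales": 100}, {"store": 1, "day": 2, "sales": 150}]
)

running = (
    Window.partitionBy("store")
    .orderBy("day")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
print(df.withColumn("running_sales", F.sum("sales").over(running)).sql(pretty=True)[0])
```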
15class DataFrameReader: 16 def __init__(self, spark: SparkSession): 17 self.spark = spark 18 19 def table(self, tableName: str) -> DataFrame: 20 from sqlglot.dataframe.sql.dataframe import DataFrame 21 from sqlglot.dataframe.sql.session import SparkSession 22 23 sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) 24 25 return DataFrame( 26 self.spark, 27 exp.Select() 28 .from_( 29 exp.to_table(tableName, dialect=SparkSession().dialect).transform( 30 SparkSession().dialect.normalize_identifier 31 ) 32 ) 33 .select( 34 *( 35 column 36 for column in sqlglot.schema.column_names( 37 tableName, dialect=SparkSession().dialect 38 ) 39 ) 40 ), 41 )
19 def table(self, tableName: str) -> DataFrame: 20 from sqlglot.dataframe.sql.dataframe import DataFrame 21 from sqlglot.dataframe.sql.session import SparkSession 22 23 sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) 24 25 return DataFrame( 26 self.spark, 27 exp.Select() 28 .from_( 29 exp.to_table(tableName, dialect=SparkSession().dialect).transform( 30 SparkSession().dialect.normalize_identifier 31 ) 32 ) 33 .select( 34 *( 35 column 36 for column in sqlglot.schema.column_names( 37 tableName, dialect=SparkSession().dialect 38 ) 39 ) 40 ), 41 )
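Because no engine is attached, table reads resolve column names from sqlglot's global schema, so the table should be registered first. A sketch in which the table name and columns are placeholders:

```python
import sqlglot
from sqlglot.dataframe.sql import SparkSession

# Hypothetical table, registered only so the reader can expand its column list.
sqlglot.schema.add_table(
    "employee", {"employee_id": "INT", "fname": "STRING", "age": "INT"}, dialect="spark"
)

spark = SparkSession()
df = spark.table("employee")
print(df.sql(pretty=True)[0])  # SELECT over the registered columns
```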
44class DataFrameWriter: 45 def __init__( 46 self, 47 df: DataFrame, 48 spark: t.Optional[SparkSession] = None, 49 mode: t.Optional[str] = None, 50 by_name: bool = False, 51 ): 52 self._df = df 53 self._spark = spark or df.spark 54 self._mode = mode 55 self._by_name = by_name 56 57 def copy(self, **kwargs) -> DataFrameWriter: 58 return DataFrameWriter( 59 **{ 60 k[1:] if k.startswith("_") else k: v 61 for k, v in object_to_dict(self, **kwargs).items() 62 } 63 ) 64 65 def sql(self, **kwargs) -> t.List[str]: 66 return self._df.sql(**kwargs) 67 68 def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter: 69 return self.copy(_mode=saveMode) 70 71 @property 72 def byName(self): 73 return self.copy(by_name=True) 74 75 def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: 76 from sqlglot.dataframe.sql.session import SparkSession 77 78 output_expression_container = exp.Insert( 79 **{ 80 "this": exp.to_table(tableName), 81 "overwrite": overwrite, 82 } 83 ) 84 df = self._df.copy(output_expression_container=output_expression_container) 85 if self._by_name: 86 columns = sqlglot.schema.column_names( 87 tableName, only_visible=True, dialect=SparkSession().dialect 88 ) 89 df = df._convert_leaf_to_cte().select(*columns) 90 91 return self.copy(_df=df) 92 93 def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): 94 if format is not None: 95 raise NotImplementedError("Providing Format in the save as table is not supported") 96 exists, replace, mode = None, None, mode or str(self._mode) 97 if mode == "append": 98 return self.insertInto(name) 99 if mode == "ignore": 100 exists = True 101 if mode == "overwrite": 102 replace = True 103 output_expression_container = exp.Create( 104 this=exp.to_table(name), 105 kind="TABLE", 106 exists=exists, 107 replace=replace, 108 ) 109 return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
75 def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: 76 from sqlglot.dataframe.sql.session import SparkSession 77 78 output_expression_container = exp.Insert( 79 **{ 80 "this": exp.to_table(tableName), 81 "overwrite": overwrite, 82 } 83 ) 84 df = self._df.copy(output_expression_container=output_expression_container) 85 if self._by_name: 86 columns = sqlglot.schema.column_names( 87 tableName, only_visible=True, dialect=SparkSession().dialect 88 ) 89 df = df._convert_leaf_to_cte().select(*columns) 90 91 return self.copy(_df=df)
93 def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): 94 if format is not None: 95 raise NotImplementedError("Providing Format in the save as table is not supported") 96 exists, replace, mode = None, None, mode or str(self._mode) 97 if mode == "append": 98 return self.insertInto(name) 99 if mode == "ignore": 100 exists = True 101 if mode == "overwrite": 102 replace = True 103 output_expression_container = exp.Create( 104 this=exp.to_table(name), 105 kind="TABLE", 106 exists=exists, 107 replace=replace, 108 ) 109 return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
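A sketch of the writer surface under the same registered-schema assumption as above; the table names are placeholders. mode("overwrite") maps to CREATE OR REPLACE TABLE, while mode("append") routes through insertInto:

```python
import sqlglot
from sqlglot.dataframe.sql import SparkSession, functions as F

sqlglot.schema.add_table("src_tbl", {"id": "INT", "name": "STRING"}, dialect="spark")

spark = SparkSession()
df = spark.table("src_tbl").where(F.col("id") > 10)

# Each generated statement comes back as a string from sql().
for statement in df.write.mode("overwrite").saveAsTable("dst_tbl").sql(pretty=True):
    print(statement)
```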