sqlglot.dataframe.sql
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
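Taken together, these classes let you build a PySpark-style pipeline and emit SQL instead of executing anything. A minimal end-to-end sketch, modeled on sqlglot's documented usage (the `employee` table, its columns, and the registered schema are illustrative):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register the table schema so the optimizer can resolve and qualify columns.
sqlglot.schema.add_table(
    "employee", {"employee_id": "INT", "age": "INT"}, dialect="spark"
)

spark = SparkSession()
df = (
    spark.table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

# .sql() returns a list of SQL strings, one per output statement.
for query in df.sql(dialect="spark"):
    print(query)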
class SparkSession:
    DEFAULT_DIALECT = "spark"
    _instance = None

    def __init__(self):
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)

    def __new__(cls, *args, **kwargs) -> SparkSession:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.tuple_(
                *map(
                    lambda x: F.lit(x).expression,
                    row if not isinstance(row, dict) else row.values(),
                )
            )
            for row in data
        ]

        sel_columns = [
            (
                F.col(name).cast(data_type).alias(name).expression
                if data_type is not None
                else F.col(name).expression
            )
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def _optimize(
        self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
    ) -> exp.Expression:
        dialect = dialect or self.dialect
        quote_identifiers(expression, dialect=dialect)
        return optimize(expression, dialect=dialect)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect))
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)
            return spark

    @classproperty
    def builder(cls) -> Builder:
        return cls.Builder()
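`createDataFrame` turns in-memory rows into a `VALUES`-backed select. A small sketch, assuming a plain list-of-names schema (the column names and values are illustrative, and the exact SQL emitted may vary by sqlglot version):

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()

# Schema given as a list of column names; each row becomes a tuple in a VALUES clause.
df = spark.createDataFrame([[1, "Jack"], [2, "Jill"]], ["id", "name"])

# Nothing is executed; .sql() just renders the query. optimize=False skips the optimizer,
# which is handy when no table schema has been registered.
print(df.sql(optimize=False)[0])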
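`SparkSession.sql` goes the other way: it parses a raw query into a `DataFrame` that can be chained further or re-rendered for another dialect. A hedged sketch (table, columns, and target dialect are illustrative; the schema is registered so the optimizer can qualify columns):

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table(
    "employee", {"employee_id": "INT", "age": "INT"}, dialect="spark"
)

spark = SparkSession()
df = spark.sql("SELECT employee_id, age FROM employee WHERE age > 30")

# The resulting DataFrame behaves like any other, so the same query can be
# rendered for a different dialect.
print(df.sql(dialect="bigquery")[0])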
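Per the `Builder` shown above, the only configuration key that is honored is `sqlframe.dialect`; other keys are accepted and ignored so PySpark-style builder chains still work. A small sketch:

from sqlglot.dataframe.sql.session import SparkSession

# The builder mimics pyspark's fluent interface; only the dialect key has an effect.
spark = (
    SparkSession.builder
    .config("sqlframe.dialect", "duckdb")
    .getOrCreate()
)
# spark.dialect is now the DuckDB Dialect instance used for parsing and generation.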
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
            expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existsing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existsing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)

        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}

        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = t.cast(
                    exp.Select, self.spark._optimize(select_expression, dialect=dialect)
                )

            select_expression = df._replace_cte_names_with_hashes(select_expression)

            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql(dialect=dialect)
                        for expression in select_expression.expressions
                    },
                    dialect=dialect,
                )

                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )

                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)

                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")

            output_expressions.append(expression)

        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
                # of Spark.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
                    # use the same CTE we used before
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine actual "join on" expression later so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and select columns to be used passed on what type of columns were provided for
        # the join. The columns returned changes based on how the on expression is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of join on column names only:
            * The column names are put at the front of the select list
            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each of the left side
            # tables and see if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
            select_column_names = [
                (
                    column.alias_or_name
                    if not isinstance(column.expression.this, exp.Star)
                    else column.sql()
                )
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of join on expressions:
            * There is no deduplication of the results.
            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
        is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for i, x in enumerate(ascending)]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            (
                exp.Ordered(this=col.expression, desc=not asc)
                if i not in pre_ordered_col_indexes
                else columns[i].column_expression
            )
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
        if minimum_num_nulls > len(null_check_columns):
            raise RuntimeError(
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality Difference: If you provide a value to replace a null and that type conflicts
        with the type of the column then PySpark will just ignore your replacement.
        This will try to cast them to be the same in some cases. So they won't always match.
        Best to not mix types so make sure replacement is the same type as the column

        Possibility for improvement: Use `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
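A few usage sketches for the most common `DataFrame` operations follow. In all of them the tables, columns, and literal values are illustrative, schemas are registered only so the optimizer can qualify columns, and the exact SQL text may differ between sqlglot versions. First, a join on a shared column name:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table(
    "employee", {"employee_id": "INT", "store_id": "INT"}, dialect="spark"
)
sqlglot.schema.add_table(
    "store", {"store_id": "INT", "store_name": "STRING"}, dialect="spark"
)

spark = SparkSession()
employees = spark.table("employee")
stores = spark.table("store")

# Joining on a column name dedupes the join key and puts it first in the select list.
df = employees.join(stores, on="store_id", how="left")
print(df.sql()[0])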
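`orderBy` (aliased as `sort`) accepts either an `ascending` flag or pre-ordered column expressions; as its docstring notes, mixing the two for the same column is best avoided. A sketch:

from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([[1, 25], [2, 40]], ["employee_id", "age"])

# Either pass an ascending flag...
print(df.orderBy("age", ascending=False).sql(optimize=False)[0])
# ...or pre-ordered column expressions.
print(df.orderBy(F.col("age").desc()).sql(optimize=False)[0])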
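`unionByName` matches columns by name rather than position, filling gaps with `NULL` when `allowMissingColumns=True`. A sketch:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df_a = spark.createDataFrame([[1, "Jack"]], ["id", "name"])
df_b = spark.createDataFrame([["Jill", 2]], ["name", "id"])

# Columns are aligned by name, so the differing column order is handled.
print(df_a.unionByName(df_b).sql(optimize=False)[0])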
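`dropna` filters on a count of nulls across the chosen subset, and `dropDuplicates` is implemented with a `row_number()` window underneath. A sketch:

from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([[1, 25], [1, 30], [2, 40]], ["id", "age"])

# Keep rows where every column in the subset is non-null.
cleaned = df.dropna(how="any", subset=["age"])
# Keep one row per id, implemented via row_number() over a window.
deduped = cleaned.dropDuplicates(["id"])
print(deduped.sql(optimize=False)[0])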
661 @operation(Operation.FROM) 662 def fillna( 663 self, 664 value: t.Union[ColumnLiterals], 665 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 666 ) -> DataFrame: 667 """ 668 Functionality Difference: If you provide a value to replace a null and that type conflicts 669 with the type of the column then PySpark will just ignore your replacement. 670 This will try to cast them to be the same in some cases. So they won't always match. 671 Best to not mix types so make sure replacement is the same type as the column 672 673 Possibility for improvement: Use `typeof` function to get the type of the column 674 and check if it matches the type of the value provided. If not then make it null. 675 """ 676 from sqlglot.dataframe.sql.functions import lit 677 678 values = None 679 columns = None 680 new_df = self.copy() 681 all_columns = self._get_outer_select_columns(new_df.expression) 682 all_column_mapping = {column.alias_or_name: column for column in all_columns} 683 if isinstance(value, dict): 684 values = list(value.values()) 685 columns = self._ensure_and_normalize_cols(list(value)) 686 if not columns: 687 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 688 if not values: 689 values = [value] * len(columns) 690 value_columns = [lit(value) for value in values] 691 692 null_replacement_mapping = { 693 column.alias_or_name: ( 694 F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) 695 ) 696 for column, value in zip(columns, value_columns) 697 } 698 null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} 699 null_replacement_columns = [ 700 null_replacement_mapping[column.alias_or_name] for column in all_columns 701 ] 702 new_df = new_df.select(*null_replacement_columns) 703 return new_df
Functionality Difference: If the replacement value's type conflicts with the column's type, PySpark simply ignores the replacement. This implementation may instead cast the two to a common type in some cases, so the results will not always match; it is best not to mix types and to keep the replacement value the same type as the column.
Possibility for improvement: use the typeof function to get the column's type, check whether it matches the type of the provided value, and, if it does not match, use null instead.
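A short sketch of both call forms, keeping each replacement value the same type as its column as advised above (names and data are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, None], [None, "x"]], ["cola", "colb"])
    scalar_fill = df.fillna(0, subset=["cola"])              # one value for the listed columns
    mapped_fill = df.fillna({"cola": 0, "colb": "missing"})  # per-column replacements
    for statement in mapped_fill.sql():
        print(statement)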
705 @operation(Operation.FROM) 706 def replace( 707 self, 708 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 709 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 710 subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, 711 ) -> DataFrame: 712 from sqlglot.dataframe.sql.functions import lit 713 714 old_values = None 715 new_df = self.copy() 716 all_columns = self._get_outer_select_columns(new_df.expression) 717 all_column_mapping = {column.alias_or_name: column for column in all_columns} 718 719 columns = self._ensure_and_normalize_cols(subset) if subset else all_columns 720 if isinstance(to_replace, dict): 721 old_values = list(to_replace) 722 new_values = list(to_replace.values()) 723 elif not old_values and isinstance(to_replace, list): 724 assert isinstance(value, list), "value must be a list since the replacements are a list" 725 assert len(to_replace) == len( 726 value 727 ), "the replacements and values must be the same length" 728 old_values = to_replace 729 new_values = value 730 else: 731 old_values = [to_replace] * len(columns) 732 new_values = [value] * len(columns) 733 old_values = [lit(value) for value in old_values] 734 new_values = [lit(value) for value in new_values] 735 736 replacement_mapping = {} 737 for column in columns: 738 expression = Column(None) 739 for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): 740 if i == 0: 741 expression = F.when(column == old_value, new_value) 742 else: 743 expression = expression.when(column == old_value, new_value) # type: ignore 744 replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( 745 column.expression.alias_or_name 746 ) 747 748 replacement_mapping = {**all_column_mapping, **replacement_mapping} 749 replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] 750 new_df = new_df.select(*replacement_columns) 751 return new_df
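A sketch of the three accepted shapes of to_replace, mirroring the branches above (illustrative data):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, 10], [2, 20]], ["cola", "colb"])
    single = df.replace(1, 100, subset=["cola"])             # one old value, one new value
    parallel = df.replace([1, 2], [100, 200])                # lists must be the same length
    mapping = df.replace({10: 11, 20: 21}, subset=["colb"])  # dict of old -> new values
    for statement in mapping.sql():
        print(statement)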
753 @operation(Operation.SELECT) 754 def withColumn(self, colName: str, col: Column) -> DataFrame: 755 col = self._ensure_and_normalize_col(col) 756 existing_col_names = self.expression.named_selects 757 existing_col_index = ( 758 existing_col_names.index(colName) if colName in existing_col_names else None 759 ) 760 if existing_col_index is not None: 761 expression = self.expression.copy() 762 expression.expressions[existing_col_index] = col.expression 763 return self.copy(expression=expression) 764 return self.copy().select(col.alias(colName), append=True)
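A sketch of both behaviours: a new name is appended as an aliased select, while an existing name is swapped in place (column names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    spark = SparkSession()
    df = spark.createDataFrame([[1, 2]], ["cola", "colb"])
    appended = df.withColumn("total", F.col("cola") + F.col("colb"))  # new column
    replaced = df.withColumn("colb", F.col("colb") * F.lit(10))       # overwrites colb
    for statement in replaced.sql():
        print(statement)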
766 @operation(Operation.SELECT) 767 def withColumnRenamed(self, existing: str, new: str): 768 expression = self.expression.copy() 769 existing_columns = [ 770 expression 771 for expression in expression.expressions 772 if expression.alias_or_name == existing 773 ] 774 if not existing_columns: 775 raise ValueError("Tried to rename a column that doesn't exist") 776 for existing_column in existing_columns: 777 if isinstance(existing_column, exp.Column): 778 existing_column.replace(exp.alias_(existing_column, new)) 779 else: 780 existing_column.set("alias", exp.to_identifier(new)) 781 return self.copy(expression=expression)
783 @operation(Operation.SELECT) 784 def drop(self, *cols: t.Union[str, Column]) -> DataFrame: 785 all_columns = self._get_outer_select_columns(self.expression) 786 drop_cols = self._ensure_and_normalize_cols(cols) 787 new_columns = [ 788 col 789 for col in all_columns 790 if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] 791 ] 792 return self.copy().select(*new_columns, append=False)
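A combined sketch of the two projection helpers above; withColumnRenamed raises if the name is unknown, and drop accepts either strings or Column objects (names are illustrative):

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, 2, 3]], ["cola", "colb", "colc"])
    renamed = df.withColumnRenamed("cola", "id")
    pruned = renamed.drop("colc")
    for statement in pruned.sql():
        print(statement)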
798 @operation(Operation.NO_OP) 799 def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: 800 parameter_list = ensure_list(parameters) 801 parameter_columns = ( 802 self._ensure_list_of_columns(parameter_list) 803 if parameters 804 else Column.ensure_cols([self.sequence_id]) 805 ) 806 return self._hint(name, parameter_columns)
808 @operation(Operation.NO_OP) 809 def repartition( 810 self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName 811 ) -> DataFrame: 812 num_partition_cols = self._ensure_list_of_columns(numPartitions) 813 columns = self._ensure_and_normalize_cols(cols) 814 args = num_partition_cols + columns 815 return self._hint("repartition", args)
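Both methods only attach hints to the generated query rather than changing its result; a hedged sketch (the exact hint syntax in the output depends on the dialect):

    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    spark = SparkSession()
    df = spark.createDataFrame([[1, "a"]], ["id", "category"])
    hinted = df.hint("broadcast")                        # no parameters: targets this DataFrame
    partitioned = df.repartition(8, F.col("category"))   # becomes a repartition hint
    for statement in partitioned.sql():
        print(statement)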
826 @operation(Operation.NO_OP) 827 def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame: 828 """ 829 Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html 830 """ 831 return self._cache(storageLevel)
Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
14class GroupedData: 15 def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation): 16 self._df = df.copy() 17 self.spark = df.spark 18 self.last_op = last_op 19 self.group_by_cols = group_by_cols 20 21 def _get_function_applied_columns( 22 self, func_name: str, cols: t.Tuple[str, ...] 23 ) -> t.List[Column]: 24 func_name = func_name.lower() 25 return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] 26 27 @operation(Operation.SELECT) 28 def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: 29 columns = ( 30 [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] 31 if isinstance(exprs[0], dict) 32 else exprs 33 ) 34 cols = self._df._ensure_and_normalize_cols(columns) 35 36 expression = self._df.expression.group_by( 37 *[x.expression for x in self.group_by_cols] 38 ).select(*[x.expression for x in self.group_by_cols + cols], append=False) 39 return self._df.copy(expression=expression) 40 41 def count(self) -> DataFrame: 42 return self.agg(F.count("*").alias("count")) 43 44 def mean(self, *cols: str) -> DataFrame: 45 return self.avg(*cols) 46 47 def avg(self, *cols: str) -> DataFrame: 48 return self.agg(*self._get_function_applied_columns("avg", cols)) 49 50 def max(self, *cols: str) -> DataFrame: 51 return self.agg(*self._get_function_applied_columns("max", cols)) 52 53 def min(self, *cols: str) -> DataFrame: 54 return self.agg(*self._get_function_applied_columns("min", cols)) 55 56 def sum(self, *cols: str) -> DataFrame: 57 return self.agg(*self._get_function_applied_columns("sum", cols)) 58 59 def pivot(self, *cols: str) -> DataFrame: 60 raise NotImplementedError("Sum distinct is not currently implemented")
27 @operation(Operation.SELECT) 28 def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: 29 columns = ( 30 [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] 31 if isinstance(exprs[0], dict) 32 else exprs 33 ) 34 cols = self._df._ensure_and_normalize_cols(columns) 35 36 expression = self._df.expression.group_by( 37 *[x.expression for x in self.group_by_cols] 38 ).select(*[x.expression for x in self.group_by_cols + cols], append=False) 39 return self._df.copy(expression=expression)
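A sketch of both agg() forms; groupBy (which returns the GroupedData wrapper above) is assumed to accept column names or Column objects as in PySpark, and the dict form maps column names to aggregate function names:

    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    spark = SparkSession()
    df = spark.createDataFrame([[1, "a", 10], [2, "a", 20]], ["id", "category", "value"])
    with_columns = df.groupBy(F.col("category")).agg(
        F.sum("value").alias("total"), F.max("value")
    )
    with_dict = df.groupBy("category").agg({"value": "avg"})  # -> avg(value)
    for statement in with_columns.sql():
        print(statement)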
16class Column: 17 def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): 18 from sqlglot.dataframe.sql.session import SparkSession 19 20 if isinstance(expression, Column): 21 expression = expression.expression # type: ignore 22 elif expression is None or not isinstance(expression, (str, exp.Expression)): 23 expression = self._lit(expression).expression # type: ignore 24 elif not isinstance(expression, exp.Column): 25 expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( 26 SparkSession().dialect.normalize_identifier, copy=False 27 ) 28 if expression is None: 29 raise ValueError(f"Could not parse {expression}") 30 31 self.expression: exp.Expression = expression # type: ignore 32 33 def __repr__(self): 34 return repr(self.expression) 35 36 def __hash__(self): 37 return hash(self.expression) 38 39 def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore 40 return self.binary_op(exp.EQ, other) 41 42 def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore 43 return self.binary_op(exp.NEQ, other) 44 45 def __gt__(self, other: ColumnOrLiteral) -> Column: 46 return self.binary_op(exp.GT, other) 47 48 def __ge__(self, other: ColumnOrLiteral) -> Column: 49 return self.binary_op(exp.GTE, other) 50 51 def __lt__(self, other: ColumnOrLiteral) -> Column: 52 return self.binary_op(exp.LT, other) 53 54 def __le__(self, other: ColumnOrLiteral) -> Column: 55 return self.binary_op(exp.LTE, other) 56 57 def __and__(self, other: ColumnOrLiteral) -> Column: 58 return self.binary_op(exp.And, other) 59 60 def __or__(self, other: ColumnOrLiteral) -> Column: 61 return self.binary_op(exp.Or, other) 62 63 def __mod__(self, other: ColumnOrLiteral) -> Column: 64 return self.binary_op(exp.Mod, other) 65 66 def __add__(self, other: ColumnOrLiteral) -> Column: 67 return self.binary_op(exp.Add, other) 68 69 def __sub__(self, other: ColumnOrLiteral) -> Column: 70 return self.binary_op(exp.Sub, other) 71 72 def __mul__(self, other: ColumnOrLiteral) -> Column: 73 return self.binary_op(exp.Mul, other) 74 75 def __truediv__(self, other: ColumnOrLiteral) -> Column: 76 return self.binary_op(exp.Div, other) 77 78 def __div__(self, other: ColumnOrLiteral) -> Column: 79 return self.binary_op(exp.Div, other) 80 81 def __neg__(self) -> Column: 82 return self.unary_op(exp.Neg) 83 84 def __radd__(self, other: ColumnOrLiteral) -> Column: 85 return self.inverse_binary_op(exp.Add, other) 86 87 def __rsub__(self, other: ColumnOrLiteral) -> Column: 88 return self.inverse_binary_op(exp.Sub, other) 89 90 def __rmul__(self, other: ColumnOrLiteral) -> Column: 91 return self.inverse_binary_op(exp.Mul, other) 92 93 def __rdiv__(self, other: ColumnOrLiteral) -> Column: 94 return self.inverse_binary_op(exp.Div, other) 95 96 def __rtruediv__(self, other: ColumnOrLiteral) -> Column: 97 return self.inverse_binary_op(exp.Div, other) 98 99 def __rmod__(self, other: ColumnOrLiteral) -> Column: 100 return self.inverse_binary_op(exp.Mod, other) 101 102 def __pow__(self, power: ColumnOrLiteral, modulo=None): 103 return Column(exp.Pow(this=self.expression, expression=Column(power).expression)) 104 105 def __rpow__(self, power: ColumnOrLiteral): 106 return Column(exp.Pow(this=Column(power).expression, expression=self.expression)) 107 108 def __invert__(self): 109 return self.unary_op(exp.Not) 110 111 def __rand__(self, other: ColumnOrLiteral) -> Column: 112 return self.inverse_binary_op(exp.And, other) 113 114 def __ror__(self, other: ColumnOrLiteral) -> Column: 115 return 
self.inverse_binary_op(exp.Or, other) 116 117 @classmethod 118 def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column: 119 return cls(value) 120 121 @classmethod 122 def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]: 123 return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args] 124 125 @classmethod 126 def _lit(cls, value: ColumnOrLiteral) -> Column: 127 if isinstance(value, dict): 128 columns = [cls._lit(v).alias(k).expression for k, v in value.items()] 129 return cls(exp.Struct(expressions=columns)) 130 return cls(exp.convert(value)) 131 132 @classmethod 133 def invoke_anonymous_function( 134 cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] 135 ) -> Column: 136 columns = [] if column is None else [cls.ensure_col(column)] 137 column_args = [cls.ensure_col(arg) for arg in args] 138 expressions = [x.expression for x in columns + column_args] 139 new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) 140 return Column(new_expression) 141 142 @classmethod 143 def invoke_expression_over_column( 144 cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs 145 ) -> Column: 146 ensured_column = None if column is None else cls.ensure_col(column) 147 ensure_expression_values = { 148 k: ( 149 [Column.ensure_col(x).expression for x in v] 150 if is_iterable(v) 151 else Column.ensure_col(v).expression 152 ) 153 for k, v in kwargs.items() 154 if v is not None 155 } 156 new_expression = ( 157 callable_expression(**ensure_expression_values) 158 if ensured_column is None 159 else callable_expression( 160 this=ensured_column.column_expression, **ensure_expression_values 161 ) 162 ) 163 return Column(new_expression) 164 165 def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: 166 return Column( 167 klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs) 168 ) 169 170 def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: 171 return Column( 172 klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs) 173 ) 174 175 def unary_op(self, klass: t.Callable, **kwargs) -> Column: 176 return Column(klass(this=self.column_expression, **kwargs)) 177 178 @property 179 def is_alias(self): 180 return isinstance(self.expression, exp.Alias) 181 182 @property 183 def is_column(self): 184 return isinstance(self.expression, exp.Column) 185 186 @property 187 def column_expression(self) -> t.Union[exp.Column, exp.Literal]: 188 return self.expression.unalias() 189 190 @property 191 def alias_or_name(self) -> str: 192 return self.expression.alias_or_name 193 194 @classmethod 195 def ensure_literal(cls, value) -> Column: 196 from sqlglot.dataframe.sql.functions import lit 197 198 if isinstance(value, cls): 199 value = value.expression 200 if not isinstance(value, exp.Literal): 201 return lit(value) 202 return Column(value) 203 204 def copy(self) -> Column: 205 return Column(self.expression.copy()) 206 207 def set_table_name(self, table_name: str, copy=False) -> Column: 208 expression = self.expression.copy() if copy else self.expression 209 expression.set("table", exp.to_identifier(table_name)) 210 return Column(expression) 211 212 def sql(self, **kwargs) -> str: 213 from sqlglot.dataframe.sql.session import SparkSession 214 215 return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) 216 
217 def alias(self, name: str) -> Column: 218 from sqlglot.dataframe.sql.session import SparkSession 219 220 dialect = SparkSession().dialect 221 alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect) 222 new_expression = exp.alias_( 223 self.column_expression, 224 alias.this if isinstance(alias, exp.Column) else name, 225 dialect=dialect, 226 ) 227 return Column(new_expression) 228 229 def asc(self) -> Column: 230 new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True) 231 return Column(new_expression) 232 233 def desc(self) -> Column: 234 new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False) 235 return Column(new_expression) 236 237 asc_nulls_first = asc 238 239 def asc_nulls_last(self) -> Column: 240 new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False) 241 return Column(new_expression) 242 243 def desc_nulls_first(self) -> Column: 244 new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True) 245 return Column(new_expression) 246 247 desc_nulls_last = desc 248 249 def when(self, condition: Column, value: t.Any) -> Column: 250 from sqlglot.dataframe.sql.functions import when 251 252 column_with_if = when(condition, value) 253 if not isinstance(self.expression, exp.Case): 254 return column_with_if 255 new_column = self.copy() 256 new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) 257 return new_column 258 259 def otherwise(self, value: t.Any) -> Column: 260 from sqlglot.dataframe.sql.functions import lit 261 262 true_value = value if isinstance(value, Column) else lit(value) 263 new_column = self.copy() 264 new_column.expression.set("default", true_value.column_expression) 265 return new_column 266 267 def isNull(self) -> Column: 268 new_expression = exp.Is(this=self.column_expression, expression=exp.Null()) 269 return Column(new_expression) 270 271 def isNotNull(self) -> Column: 272 new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) 273 return Column(new_expression) 274 275 def cast(self, dataType: t.Union[str, DataType]) -> Column: 276 """ 277 Functionality Difference: PySpark cast accepts a datatype instance of the datatype class 278 Sqlglot doesn't currently replicate this class so it only accepts a string 279 """ 280 from sqlglot.dataframe.sql.session import SparkSession 281 282 if isinstance(dataType, DataType): 283 dataType = dataType.simpleString() 284 return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect)) 285 286 def startswith(self, value: t.Union[str, Column]) -> Column: 287 value = self._lit(value) if not isinstance(value, Column) else value 288 return self.invoke_anonymous_function(self, "STARTSWITH", value) 289 290 def endswith(self, value: t.Union[str, Column]) -> Column: 291 value = self._lit(value) if not isinstance(value, Column) else value 292 return self.invoke_anonymous_function(self, "ENDSWITH", value) 293 294 def rlike(self, regexp: str) -> Column: 295 return self.invoke_expression_over_column( 296 column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression 297 ) 298 299 def like(self, other: str): 300 return self.invoke_expression_over_column( 301 self, exp.Like, expression=self._lit(other).expression 302 ) 303 304 def ilike(self, other: str): 305 return self.invoke_expression_over_column( 306 self, exp.ILike, expression=self._lit(other).expression 307 ) 308 309 def substr(self, startPos: 
t.Union[int, Column], length: t.Union[int, Column]) -> Column: 310 startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos 311 length = self._lit(length) if not isinstance(length, Column) else length 312 return Column.invoke_expression_over_column( 313 self, exp.Substring, start=startPos.expression, length=length.expression 314 ) 315 316 def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): 317 columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 318 expressions = [self._lit(x).expression for x in columns] 319 return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore 320 321 def between( 322 self, 323 lowerBound: t.Union[ColumnOrLiteral], 324 upperBound: t.Union[ColumnOrLiteral], 325 ) -> Column: 326 lower_bound_exp = ( 327 self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound 328 ) 329 upper_bound_exp = ( 330 self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound 331 ) 332 return Column( 333 exp.Between( 334 this=self.column_expression, 335 low=lower_bound_exp.expression, 336 high=upper_bound_exp.expression, 337 ) 338 ) 339 340 def over(self, window: WindowSpec) -> Column: 341 window_expression = window.expression.copy() 342 window_expression.set("this", self.column_expression) 343 return Column(window_expression)
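The overloaded operators above build sqlglot expressions instead of evaluating anything; a small sketch (column names are illustrative, and Column.sql() renders with the active session's dialect):

    from sqlglot.dataframe.sql import functions as F

    taxed = F.col("price") * F.lit(1.2) + F.col("tax")
    expensive = taxed >= F.lit(100)
    flagged = expensive & F.col("in_stock")  # AND of the two boolean expressions
    print(flagged.sql())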
17 def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): 18 from sqlglot.dataframe.sql.session import SparkSession 19 20 if isinstance(expression, Column): 21 expression = expression.expression # type: ignore 22 elif expression is None or not isinstance(expression, (str, exp.Expression)): 23 expression = self._lit(expression).expression # type: ignore 24 elif not isinstance(expression, exp.Column): 25 expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( 26 SparkSession().dialect.normalize_identifier, copy=False 27 ) 28 if expression is None: 29 raise ValueError(f"Could not parse {expression}") 30 31 self.expression: exp.Expression = expression # type: ignore
132 @classmethod 133 def invoke_anonymous_function( 134 cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] 135 ) -> Column: 136 columns = [] if column is None else [cls.ensure_col(column)] 137 column_args = [cls.ensure_col(arg) for arg in args] 138 expressions = [x.expression for x in columns + column_args] 139 new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) 140 return Column(new_expression)
142 @classmethod 143 def invoke_expression_over_column( 144 cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs 145 ) -> Column: 146 ensured_column = None if column is None else cls.ensure_col(column) 147 ensure_expression_values = { 148 k: ( 149 [Column.ensure_col(x).expression for x in v] 150 if is_iterable(v) 151 else Column.ensure_col(v).expression 152 ) 153 for k, v in kwargs.items() 154 if v is not None 155 } 156 new_expression = ( 157 callable_expression(**ensure_expression_values) 158 if ensured_column is None 159 else callable_expression( 160 this=ensured_column.column_expression, **ensure_expression_values 161 ) 162 ) 163 return Column(new_expression)
217 def alias(self, name: str) -> Column: 218 from sqlglot.dataframe.sql.session import SparkSession 219 220 dialect = SparkSession().dialect 221 alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect) 222 new_expression = exp.alias_( 223 self.column_expression, 224 alias.this if isinstance(alias, exp.Column) else name, 225 dialect=dialect, 226 ) 227 return Column(new_expression)
249 def when(self, condition: Column, value: t.Any) -> Column: 250 from sqlglot.dataframe.sql.functions import when 251 252 column_with_if = when(condition, value) 253 if not isinstance(self.expression, exp.Case): 254 return column_with_if 255 new_column = self.copy() 256 new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) 257 return new_column
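A sketch of chained conditions; the first when() comes from the functions module and each subsequent branch extends the same CASE expression:

    from sqlglot.dataframe.sql import functions as F

    age_bucket = (
        F.when(F.col("age") < F.lit(18), F.lit("minor"))
        .when(F.col("age") < F.lit(65), F.lit("adult"))
        .otherwise(F.lit("senior"))
        .alias("age_bucket")
    )
    print(age_bucket.sql())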
275 def cast(self, dataType: t.Union[str, DataType]) -> Column: 276 """ 277 Functionality Difference: PySpark cast accepts a datatype instance of the datatype class 278 Sqlglot doesn't currently replicate this class so it only accepts a string 279 """ 280 from sqlglot.dataframe.sql.session import SparkSession 281 282 if isinstance(dataType, DataType): 283 dataType = dataType.simpleString() 284 return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
Functionality Difference: PySpark's cast accepts an instance of the DataType class; sqlglot doesn't currently replicate this class, so it only accepts a string.
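A minimal sketch; the exact rendered SQL depends on the session's dialect:

    from sqlglot.dataframe.sql import functions as F

    as_int = F.col("age").cast("int")  # a DataType instance would be converted via simpleString()
    print(as_int.sql())                # e.g. CAST(age AS INT)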
309 def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: 310 startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos 311 length = self._lit(length) if not isinstance(length, Column) else length 312 return Column.invoke_expression_over_column( 313 self, exp.Substring, start=startPos.expression, length=length.expression 314 )
316 def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): 317 columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 318 expressions = [self._lit(x).expression for x in columns] 319 return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore
321 def between( 322 self, 323 lowerBound: t.Union[ColumnOrLiteral], 324 upperBound: t.Union[ColumnOrLiteral], 325 ) -> Column: 326 lower_bound_exp = ( 327 self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound 328 ) 329 upper_bound_exp = ( 330 self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound 331 ) 332 return Column( 333 exp.Between( 334 this=self.column_expression, 335 low=lower_bound_exp.expression, 336 high=upper_bound_exp.expression, 337 ) 338 )
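A few of the predicate helpers side by side (column names are illustrative):

    from sqlglot.dataframe.sql import functions as F

    in_range = F.col("age").between(18, 65)             # age BETWEEN 18 AND 65
    in_set = F.col("country").isin(["US", "CA", "MX"])  # country IN (...)
    initial = F.col("name").substr(1, 1)                # SUBSTRING(name, 1, 1)
    print(in_range.sql())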
834class DataFrameNaFunctions: 835 def __init__(self, df: DataFrame): 836 self.df = df 837 838 def drop( 839 self, 840 how: str = "any", 841 thresh: t.Optional[int] = None, 842 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 843 ) -> DataFrame: 844 return self.df.dropna(how=how, thresh=thresh, subset=subset) 845 846 def fill( 847 self, 848 value: t.Union[int, bool, float, str, t.Dict[str, t.Any]], 849 subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, 850 ) -> DataFrame: 851 return self.df.fillna(value=value, subset=subset) 852 853 def replace( 854 self, 855 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 856 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 857 subset: t.Optional[t.Union[str, t.List[str]]] = None, 858 ) -> DataFrame: 859 return self.df.replace(to_replace=to_replace, value=value, subset=subset)
853 def replace( 854 self, 855 to_replace: t.Union[bool, int, float, str, t.List, t.Dict], 856 value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, 857 subset: t.Optional[t.Union[str, t.List[str]]] = None, 858 ) -> DataFrame: 859 return self.df.replace(to_replace=to_replace, value=value, subset=subset)
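These are thin wrappers over dropna/fillna/replace; a sketch assuming the DataFrame exposes them through an na property, mirroring PySpark:

    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.createDataFrame([[1, None], [None, 2]], ["cola", "colb"])
    cleaned = df.na.drop(how="any")                   # assumed: df.na -> DataFrameNaFunctions
    filled = df.na.fill(0, subset=["cola"])
    swapped = df.na.replace(1, 100, subset=["cola"])
    for statement in swapped.sql():
        print(statement)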
15class Window: 16 _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 17 _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 18 _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG) 19 _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG) 20 21 unboundedPreceding: int = _JAVA_MIN_LONG 22 23 unboundedFollowing: int = _JAVA_MAX_LONG 24 25 currentRow: int = 0 26 27 @classmethod 28 def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 29 return WindowSpec().partitionBy(*cols) 30 31 @classmethod 32 def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 33 return WindowSpec().orderBy(*cols) 34 35 @classmethod 36 def rowsBetween(cls, start: int, end: int) -> WindowSpec: 37 return WindowSpec().rowsBetween(start, end) 38 39 @classmethod 40 def rangeBetween(cls, start: int, end: int) -> WindowSpec: 41 return WindowSpec().rangeBetween(start, end)
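A sketch of building a window specification and applying it to a window function (column names are illustrative):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.window import Window

    window = Window.partitionBy("department").orderBy(F.col("salary").desc())
    salary_rank = F.row_number().over(window).alias("salary_rank")
    print(salary_rank.sql())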
44class WindowSpec: 45 def __init__(self, expression: exp.Expression = exp.Window()): 46 self.expression = expression 47 48 def copy(self): 49 return WindowSpec(self.expression.copy()) 50 51 def sql(self, **kwargs) -> str: 52 from sqlglot.dataframe.sql.session import SparkSession 53 54 return self.expression.sql(dialect=SparkSession().dialect, **kwargs) 55 56 def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 57 from sqlglot.dataframe.sql.column import Column 58 59 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 60 expressions = [Column.ensure_col(x).expression for x in cols] 61 window_spec = self.copy() 62 partition_by_expressions = window_spec.expression.args.get("partition_by", []) 63 partition_by_expressions.extend(expressions) 64 window_spec.expression.set("partition_by", partition_by_expressions) 65 return window_spec 66 67 def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 68 from sqlglot.dataframe.sql.column import Column 69 70 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 71 expressions = [Column.ensure_col(x).expression for x in cols] 72 window_spec = self.copy() 73 if window_spec.expression.args.get("order") is None: 74 window_spec.expression.set("order", exp.Order(expressions=[])) 75 order_by = window_spec.expression.args["order"].expressions 76 order_by.extend(expressions) 77 window_spec.expression.args["order"].set("expressions", order_by) 78 return window_spec 79 80 def _calc_start_end( 81 self, start: int, end: int 82 ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]: 83 kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = { 84 "start_side": None, 85 "end_side": None, 86 } 87 if start == Window.currentRow: 88 kwargs["start"] = "CURRENT ROW" 89 else: 90 kwargs = { 91 **kwargs, 92 **{ 93 "start_side": "PRECEDING", 94 "start": ( 95 "UNBOUNDED" 96 if start <= Window.unboundedPreceding 97 else F.lit(start).expression 98 ), 99 }, 100 } 101 if end == Window.currentRow: 102 kwargs["end"] = "CURRENT ROW" 103 else: 104 kwargs = { 105 **kwargs, 106 **{ 107 "end_side": "FOLLOWING", 108 "end": ( 109 "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression 110 ), 111 }, 112 } 113 return kwargs 114 115 def rowsBetween(self, start: int, end: int) -> WindowSpec: 116 window_spec = self.copy() 117 spec = self._calc_start_end(start, end) 118 spec["kind"] = "ROWS" 119 window_spec.expression.set( 120 "spec", 121 exp.WindowSpec( 122 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 123 ), 124 ) 125 return window_spec 126 127 def rangeBetween(self, start: int, end: int) -> WindowSpec: 128 window_spec = self.copy() 129 spec = self._calc_start_end(start, end) 130 spec["kind"] = "RANGE" 131 window_spec.expression.set( 132 "spec", 133 exp.WindowSpec( 134 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 135 ), 136 ) 137 return window_spec
56 def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 57 from sqlglot.dataframe.sql.column import Column 58 59 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 60 expressions = [Column.ensure_col(x).expression for x in cols] 61 window_spec = self.copy() 62 partition_by_expressions = window_spec.expression.args.get("partition_by", []) 63 partition_by_expressions.extend(expressions) 64 window_spec.expression.set("partition_by", partition_by_expressions) 65 return window_spec
67 def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: 68 from sqlglot.dataframe.sql.column import Column 69 70 cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore 71 expressions = [Column.ensure_col(x).expression for x in cols] 72 window_spec = self.copy() 73 if window_spec.expression.args.get("order") is None: 74 window_spec.expression.set("order", exp.Order(expressions=[])) 75 order_by = window_spec.expression.args["order"].expressions 76 order_by.extend(expressions) 77 window_spec.expression.args["order"].set("expressions", order_by) 78 return window_spec
115 def rowsBetween(self, start: int, end: int) -> WindowSpec: 116 window_spec = self.copy() 117 spec = self._calc_start_end(start, end) 118 spec["kind"] = "ROWS" 119 window_spec.expression.set( 120 "spec", 121 exp.WindowSpec( 122 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 123 ), 124 ) 125 return window_spec
127 def rangeBetween(self, start: int, end: int) -> WindowSpec: 128 window_spec = self.copy() 129 spec = self._calc_start_end(start, end) 130 spec["kind"] = "RANGE" 131 window_spec.expression.set( 132 "spec", 133 exp.WindowSpec( 134 **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} 135 ), 136 ) 137 return window_spec
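A sketch of a frame clause using the sentinel values defined on Window, here a running total per department (names are illustrative):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.window import Window

    frame = (
        Window.partitionBy("department")
        .orderBy("hire_date")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    running_total = F.sum("salary").over(frame).alias("running_total")
    print(running_total.sql())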
15class DataFrameReader: 16 def __init__(self, spark: SparkSession): 17 self.spark = spark 18 19 def table(self, tableName: str) -> DataFrame: 20 from sqlglot.dataframe.sql.dataframe import DataFrame 21 from sqlglot.dataframe.sql.session import SparkSession 22 23 sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) 24 25 return DataFrame( 26 self.spark, 27 exp.Select() 28 .from_( 29 exp.to_table(tableName, dialect=SparkSession().dialect).transform( 30 SparkSession().dialect.normalize_identifier 31 ) 32 ) 33 .select( 34 *( 35 column 36 for column in sqlglot.schema.column_names( 37 tableName, dialect=SparkSession().dialect 38 ) 39 ) 40 ), 41 )
19 def table(self, tableName: str) -> DataFrame: 20 from sqlglot.dataframe.sql.dataframe import DataFrame 21 from sqlglot.dataframe.sql.session import SparkSession 22 23 sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) 24 25 return DataFrame( 26 self.spark, 27 exp.Select() 28 .from_( 29 exp.to_table(tableName, dialect=SparkSession().dialect).transform( 30 SparkSession().dialect.normalize_identifier 31 ) 32 ) 33 .select( 34 *( 35 column 36 for column in sqlglot.schema.column_names( 37 tableName, dialect=SparkSession().dialect 38 ) 39 ) 40 ), 41 )
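Reading a table requires its schema to be registered first so column names can be resolved; a sketch with an illustrative table:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table(
        "employee",
        {"employee_id": "INT", "fname": "STRING", "lname": "STRING", "age": "INT"},
        dialect="spark",
    )
    df = SparkSession().read.table("employee")
    for statement in df.sql():
        print(statement)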
44class DataFrameWriter: 45 def __init__( 46 self, 47 df: DataFrame, 48 spark: t.Optional[SparkSession] = None, 49 mode: t.Optional[str] = None, 50 by_name: bool = False, 51 ): 52 self._df = df 53 self._spark = spark or df.spark 54 self._mode = mode 55 self._by_name = by_name 56 57 def copy(self, **kwargs) -> DataFrameWriter: 58 return DataFrameWriter( 59 **{ 60 k[1:] if k.startswith("_") else k: v 61 for k, v in object_to_dict(self, **kwargs).items() 62 } 63 ) 64 65 def sql(self, **kwargs) -> t.List[str]: 66 return self._df.sql(**kwargs) 67 68 def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter: 69 return self.copy(_mode=saveMode) 70 71 @property 72 def byName(self): 73 return self.copy(by_name=True) 74 75 def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: 76 from sqlglot.dataframe.sql.session import SparkSession 77 78 output_expression_container = exp.Insert( 79 **{ 80 "this": exp.to_table(tableName), 81 "overwrite": overwrite, 82 } 83 ) 84 df = self._df.copy(output_expression_container=output_expression_container) 85 if self._by_name: 86 columns = sqlglot.schema.column_names( 87 tableName, only_visible=True, dialect=SparkSession().dialect 88 ) 89 df = df._convert_leaf_to_cte().select(*columns) 90 91 return self.copy(_df=df) 92 93 def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): 94 if format is not None: 95 raise NotImplementedError("Providing Format in the save as table is not supported") 96 exists, replace, mode = None, None, mode or str(self._mode) 97 if mode == "append": 98 return self.insertInto(name) 99 if mode == "ignore": 100 exists = True 101 if mode == "overwrite": 102 replace = True 103 output_expression_container = exp.Create( 104 this=exp.to_table(name), 105 kind="TABLE", 106 exists=exists, 107 replace=replace, 108 ) 109 return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
75 def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: 76 from sqlglot.dataframe.sql.session import SparkSession 77 78 output_expression_container = exp.Insert( 79 **{ 80 "this": exp.to_table(tableName), 81 "overwrite": overwrite, 82 } 83 ) 84 df = self._df.copy(output_expression_container=output_expression_container) 85 if self._by_name: 86 columns = sqlglot.schema.column_names( 87 tableName, only_visible=True, dialect=SparkSession().dialect 88 ) 89 df = df._convert_leaf_to_cte().select(*columns) 90 91 return self.copy(_df=df)
93 def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): 94 if format is not None: 95 raise NotImplementedError("Providing Format in the save as table is not supported") 96 exists, replace, mode = None, None, mode or str(self._mode) 97 if mode == "append": 98 return self.insertInto(name) 99 if mode == "ignore": 100 exists = True 101 if mode == "overwrite": 102 replace = True 103 output_expression_container = exp.Create( 104 this=exp.to_table(name), 105 kind="TABLE", 106 exists=exists, 107 replace=replace, 108 ) 109 return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
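A sketch of the two output paths, assuming the DataFrame exposes this writer through a write property as in PySpark; insertInto produces an INSERT statement, saveAsTable with mode="overwrite" produces a CREATE (or replace) TABLE ... AS SELECT, and sql() returns the statements as strings:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"}, dialect="spark")
    df = SparkSession().read.table("employee")
    insert_writer = df.write.insertInto("employee_backup")  # assumed: df.write -> DataFrameWriter
    create_writer = df.write.mode("overwrite").saveAsTable("employee_snapshot")
    for statement in insert_writer.sql() + create_writer.sql():
        print(statement)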