sqlglot.schema
1from __future__ import annotations 2 3import abc 4import typing as t 5 6from sqlglot import expressions as exp 7from sqlglot.dialects.dialect import Dialect 8from sqlglot.errors import SchemaError 9from sqlglot.helper import dict_depth 10from sqlglot.trie import TrieResult, in_trie, new_trie 11 12if t.TYPE_CHECKING: 13 from sqlglot.dataframe.sql.types import StructType 14 from sqlglot.dialects.dialect import DialectType 15 16 ColumnMapping = t.Union[t.Dict, str, StructType, t.List] 17 18TABLE_ARGS = ("this", "db", "catalog") 19 20 21class Schema(abc.ABC): 22 """Abstract base class for database schemas""" 23 24 dialect: DialectType 25 26 @abc.abstractmethod 27 def add_table( 28 self, 29 table: exp.Table | str, 30 column_mapping: t.Optional[ColumnMapping] = None, 31 dialect: DialectType = None, 32 normalize: t.Optional[bool] = None, 33 match_depth: bool = True, 34 ) -> None: 35 """ 36 Register or update a table. Some implementing classes may require column information to also be provided. 37 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 38 39 Args: 40 table: the `Table` expression instance or string representing the table. 41 column_mapping: a column mapping that describes the structure of the table. 42 dialect: the SQL dialect that will be used to parse `table` if it's a string. 43 normalize: whether to normalize identifiers according to the dialect of interest. 44 match_depth: whether to enforce that the table must match the schema's depth or not. 45 """ 46 47 @abc.abstractmethod 48 def column_names( 49 self, 50 table: exp.Table | str, 51 only_visible: bool = False, 52 dialect: DialectType = None, 53 normalize: t.Optional[bool] = None, 54 ) -> t.List[str]: 55 """ 56 Get the column names for a table. 57 58 Args: 59 table: the `Table` expression instance. 60 only_visible: whether to include invisible columns. 61 dialect: the SQL dialect that will be used to parse `table` if it's a string. 
62 normalize: whether to normalize identifiers according to the dialect of interest. 63 64 Returns: 65 The list of column names. 66 """ 67 68 @abc.abstractmethod 69 def get_column_type( 70 self, 71 table: exp.Table | str, 72 column: exp.Column | str, 73 dialect: DialectType = None, 74 normalize: t.Optional[bool] = None, 75 ) -> exp.DataType: 76 """ 77 Get the `sqlglot.exp.DataType` type of a column in the schema. 78 79 Args: 80 table: the source table. 81 column: the target column. 82 dialect: the SQL dialect that will be used to parse `table` if it's a string. 83 normalize: whether to normalize identifiers according to the dialect of interest. 84 85 Returns: 86 The resulting column type. 87 """ 88 89 def has_column( 90 self, 91 table: exp.Table | str, 92 column: exp.Column | str, 93 dialect: DialectType = None, 94 normalize: t.Optional[bool] = None, 95 ) -> bool: 96 """ 97 Returns whether or not `column` appears in `table`'s schema. 98 99 Args: 100 table: the source table. 101 column: the target column. 102 dialect: the SQL dialect that will be used to parse `table` if it's a string. 103 normalize: whether to normalize identifiers according to the dialect of interest. 104 105 Returns: 106 True if the column appears in the schema, False otherwise. 107 """ 108 name = column if isinstance(column, str) else column.name 109 return name in self.column_names(table, dialect=dialect, normalize=normalize) 110 111 @property 112 @abc.abstractmethod 113 def supported_table_args(self) -> t.Tuple[str, ...]: 114 """ 115 Table arguments this schema support, e.g. 
`("this", "db", "catalog")` 116 """ 117 118 @property 119 def empty(self) -> bool: 120 """Returns whether or not the schema is empty.""" 121 return True 122 123 124class AbstractMappingSchema: 125 def __init__( 126 self, 127 mapping: t.Optional[t.Dict] = None, 128 ) -> None: 129 self.mapping = mapping or {} 130 self.mapping_trie = new_trie( 131 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 132 ) 133 self._supported_table_args: t.Tuple[str, ...] = tuple() 134 135 @property 136 def empty(self) -> bool: 137 return not self.mapping 138 139 def depth(self) -> int: 140 return dict_depth(self.mapping) 141 142 @property 143 def supported_table_args(self) -> t.Tuple[str, ...]: 144 if not self._supported_table_args and self.mapping: 145 depth = self.depth() 146 147 if not depth: # None 148 self._supported_table_args = tuple() 149 elif 1 <= depth <= 3: 150 self._supported_table_args = TABLE_ARGS[:depth] 151 else: 152 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 153 154 return self._supported_table_args 155 156 def table_parts(self, table: exp.Table) -> t.List[str]: 157 if isinstance(table.this, exp.ReadCSV): 158 return [table.this.name] 159 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 160 161 def find( 162 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 163 ) -> t.Optional[t.Any]: 164 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 165 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 166 167 if value == TrieResult.FAILED: 168 return None 169 170 if value == TrieResult.PREFIX: 171 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 172 173 if len(possibilities) == 1: 174 parts.extend(possibilities[0]) 175 else: 176 message = ", ".join(".".join(parts) for parts in possibilities) 177 if raise_on_missing: 178 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 179 return None 180 181 return 
self.nested_get(parts, raise_on_missing=raise_on_missing) 182 183 def nested_get( 184 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 185 ) -> t.Optional[t.Any]: 186 return nested_get( 187 d or self.mapping, 188 *zip(self.supported_table_args, reversed(parts)), 189 raise_on_missing=raise_on_missing, 190 ) 191 192 193class MappingSchema(AbstractMappingSchema, Schema): 194 """ 195 Schema based on a nested mapping. 196 197 Args: 198 schema: Mapping in one of the following forms: 199 1. {table: {col: type}} 200 2. {db: {table: {col: type}}} 201 3. {catalog: {db: {table: {col: type}}}} 202 4. None - Tables will be added later 203 visible: Optional mapping of which columns in the schema are visible. If not provided, all columns 204 are assumed to be visible. The nesting should mirror that of the schema: 205 1. {table: set(*cols)}} 206 2. {db: {table: set(*cols)}}} 207 3. {catalog: {db: {table: set(*cols)}}}} 208 dialect: The dialect to be used for custom type mappings & parsing string arguments. 209 normalize: Whether to normalize identifier names according to the given dialect or not. 
210 """ 211 212 def __init__( 213 self, 214 schema: t.Optional[t.Dict] = None, 215 visible: t.Optional[t.Dict] = None, 216 dialect: DialectType = None, 217 normalize: bool = True, 218 ) -> None: 219 self.dialect = dialect 220 self.visible = visible or {} 221 self.normalize = normalize 222 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 223 self._depth = 0 224 225 super().__init__(self._normalize(schema or {})) 226 227 @classmethod 228 def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: 229 return MappingSchema( 230 schema=mapping_schema.mapping, 231 visible=mapping_schema.visible, 232 dialect=mapping_schema.dialect, 233 normalize=mapping_schema.normalize, 234 ) 235 236 def copy(self, **kwargs) -> MappingSchema: 237 return MappingSchema( 238 **{ # type: ignore 239 "schema": self.mapping.copy(), 240 "visible": self.visible.copy(), 241 "dialect": self.dialect, 242 "normalize": self.normalize, 243 **kwargs, 244 } 245 ) 246 247 def add_table( 248 self, 249 table: exp.Table | str, 250 column_mapping: t.Optional[ColumnMapping] = None, 251 dialect: DialectType = None, 252 normalize: t.Optional[bool] = None, 253 match_depth: bool = True, 254 ) -> None: 255 """ 256 Register or update a table. Updates are only performed if a new column mapping is provided. 257 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 258 259 Args: 260 table: the `Table` expression instance or string representing the table. 261 column_mapping: a column mapping that describes the structure of the table. 262 dialect: the SQL dialect that will be used to parse `table` if it's a string. 263 normalize: whether to normalize identifiers according to the dialect of interest. 264 match_depth: whether to enforce that the table must match the schema's depth or not. 
265 """ 266 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 267 268 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 269 raise SchemaError( 270 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 271 f"schema's nesting level: {self.depth()}." 272 ) 273 274 normalized_column_mapping = { 275 self._normalize_name(key, dialect=dialect, normalize=normalize): value 276 for key, value in ensure_column_mapping(column_mapping).items() 277 } 278 279 schema = self.find(normalized_table, raise_on_missing=False) 280 if schema and not normalized_column_mapping: 281 return 282 283 parts = self.table_parts(normalized_table) 284 285 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 286 new_trie([parts], self.mapping_trie) 287 288 def column_names( 289 self, 290 table: exp.Table | str, 291 only_visible: bool = False, 292 dialect: DialectType = None, 293 normalize: t.Optional[bool] = None, 294 ) -> t.List[str]: 295 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 296 297 schema = self.find(normalized_table) 298 if schema is None: 299 return [] 300 301 if not only_visible or not self.visible: 302 return list(schema) 303 304 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 305 return [col for col in schema if col in visible] 306 307 def get_column_type( 308 self, 309 table: exp.Table | str, 310 column: exp.Column | str, 311 dialect: DialectType = None, 312 normalize: t.Optional[bool] = None, 313 ) -> exp.DataType: 314 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 315 316 normalized_column_name = self._normalize_name( 317 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 318 ) 319 320 table_schema = self.find(normalized_table, raise_on_missing=False) 321 if table_schema: 322 column_type = 
table_schema.get(normalized_column_name) 323 324 if isinstance(column_type, exp.DataType): 325 return column_type 326 elif isinstance(column_type, str): 327 return self._to_data_type(column_type, dialect=dialect) 328 329 return exp.DataType.build("unknown") 330 331 def has_column( 332 self, 333 table: exp.Table | str, 334 column: exp.Column | str, 335 dialect: DialectType = None, 336 normalize: t.Optional[bool] = None, 337 ) -> bool: 338 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 339 340 normalized_column_name = self._normalize_name( 341 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 342 ) 343 344 table_schema = self.find(normalized_table, raise_on_missing=False) 345 return normalized_column_name in table_schema if table_schema else False 346 347 def _normalize(self, schema: t.Dict) -> t.Dict: 348 """ 349 Normalizes all identifiers in the schema. 350 351 Args: 352 schema: the schema to normalize. 353 354 Returns: 355 The normalized schema mapping. 356 """ 357 normalized_mapping: t.Dict = {} 358 flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) 359 360 for keys in flattened_schema: 361 columns = nested_get(schema, *zip(keys, keys)) 362 363 if not isinstance(columns, dict): 364 raise SchemaError( 365 f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}." 
366 ) 367 368 normalized_keys = [ 369 self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys 370 ] 371 for column_name, column_type in columns.items(): 372 nested_set( 373 normalized_mapping, 374 normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)], 375 column_type, 376 ) 377 378 return normalized_mapping 379 380 def _normalize_table( 381 self, 382 table: exp.Table | str, 383 dialect: DialectType = None, 384 normalize: t.Optional[bool] = None, 385 ) -> exp.Table: 386 normalized_table = exp.maybe_parse( 387 table, into=exp.Table, dialect=dialect or self.dialect, copy=True 388 ) 389 390 for arg in TABLE_ARGS: 391 value = normalized_table.args.get(arg) 392 if isinstance(value, (str, exp.Identifier)): 393 normalized_table.set( 394 arg, 395 exp.to_identifier( 396 self._normalize_name( 397 value, dialect=dialect, is_table=True, normalize=normalize 398 ) 399 ), 400 ) 401 402 return normalized_table 403 404 def _normalize_name( 405 self, 406 name: str | exp.Identifier, 407 dialect: DialectType = None, 408 is_table: bool = False, 409 normalize: t.Optional[bool] = None, 410 ) -> str: 411 return normalize_name( 412 name, 413 dialect=dialect or self.dialect, 414 is_table=is_table, 415 normalize=self.normalize if normalize is None else normalize, 416 ) 417 418 def depth(self) -> int: 419 if not self.empty and not self._depth: 420 # The columns themselves are a mapping, but we don't want to include those 421 self._depth = super().depth() - 1 422 return self._depth 423 424 def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: 425 """ 426 Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. 427 428 Args: 429 schema_type: the type we want to convert. 430 dialect: the SQL dialect that will be used to parse `schema_type`, if needed. 431 432 Returns: 433 The resulting expression type. 
434 """ 435 if schema_type not in self._type_mapping_cache: 436 dialect = dialect or self.dialect 437 udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES 438 439 try: 440 expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt) 441 self._type_mapping_cache[schema_type] = expression 442 except AttributeError: 443 in_dialect = f" in dialect {dialect}" if dialect else "" 444 raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") 445 446 return self._type_mapping_cache[schema_type] 447 448 449def normalize_name( 450 identifier: str | exp.Identifier, 451 dialect: DialectType = None, 452 is_table: bool = False, 453 normalize: t.Optional[bool] = True, 454) -> str: 455 if isinstance(identifier, str): 456 identifier = exp.parse_identifier(identifier, dialect=dialect) 457 458 if not normalize: 459 return identifier.name 460 461 # This can be useful for normalize_identifier 462 identifier.meta["is_table"] = is_table 463 return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name 464 465 466def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: 467 if isinstance(schema, Schema): 468 return schema 469 470 return MappingSchema(schema, **kwargs) 471 472 473def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 474 if mapping is None: 475 return {} 476 elif isinstance(mapping, dict): 477 return mapping 478 elif isinstance(mapping, str): 479 col_name_type_strs = [x.strip() for x in mapping.split(",")] 480 return { 481 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 482 for name_type_str in col_name_type_strs 483 } 484 # Check if mapping looks like a DataFrame StructType 485 elif hasattr(mapping, "simpleString"): 486 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 487 elif isinstance(mapping, list): 488 return {x.strip(): None for x in mapping} 489 490 raise ValueError(f"Invalid mapping provided: 
{type(mapping)}") 491 492 493def flatten_schema( 494 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 495) -> t.List[t.List[str]]: 496 tables = [] 497 keys = keys or [] 498 499 for k, v in schema.items(): 500 if depth >= 2: 501 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 502 elif depth == 1: 503 tables.append(keys + [k]) 504 505 return tables 506 507 508def nested_get( 509 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 510) -> t.Optional[t.Any]: 511 """ 512 Get a value for a nested dictionary. 513 514 Args: 515 d: the dictionary to search. 516 *path: tuples of (name, key), where: 517 `key` is the key in the dictionary to get. 518 `name` is a string to use in the error if `key` isn't found. 519 520 Returns: 521 The value or None if it doesn't exist. 522 """ 523 for name, key in path: 524 d = d.get(key) # type: ignore 525 if d is None: 526 if raise_on_missing: 527 name = "table" if name == "this" else name 528 raise ValueError(f"Unknown {name}: {key}") 529 return None 530 531 return d 532 533 534def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 535 """ 536 In-place set a value for a nested dictionary 537 538 Example: 539 >>> nested_set({}, ["top_key", "second_key"], "value") 540 {'top_key': {'second_key': 'value'}} 541 542 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 543 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 544 545 Args: 546 d: dictionary to update. 547 keys: the keys that makeup the path to `value`. 548 value: the value to set in the dictionary for the given key path. 549 550 Returns: 551 The (possibly) updated dictionary. 552 """ 553 if not keys: 554 return d 555 556 if len(keys) == 1: 557 d[keys[0]] = value 558 return d 559 560 subd = d 561 for key in keys[:-1]: 562 if key not in subd: 563 subd = subd.setdefault(key, {}) 564 else: 565 subd = subd[key] 566 567 subd[keys[-1]] = value 568 return d
22class Schema(abc.ABC): 23 """Abstract base class for database schemas""" 24 25 dialect: DialectType 26 27 @abc.abstractmethod 28 def add_table( 29 self, 30 table: exp.Table | str, 31 column_mapping: t.Optional[ColumnMapping] = None, 32 dialect: DialectType = None, 33 normalize: t.Optional[bool] = None, 34 match_depth: bool = True, 35 ) -> None: 36 """ 37 Register or update a table. Some implementing classes may require column information to also be provided. 38 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 39 40 Args: 41 table: the `Table` expression instance or string representing the table. 42 column_mapping: a column mapping that describes the structure of the table. 43 dialect: the SQL dialect that will be used to parse `table` if it's a string. 44 normalize: whether to normalize identifiers according to the dialect of interest. 45 match_depth: whether to enforce that the table must match the schema's depth or not. 46 """ 47 48 @abc.abstractmethod 49 def column_names( 50 self, 51 table: exp.Table | str, 52 only_visible: bool = False, 53 dialect: DialectType = None, 54 normalize: t.Optional[bool] = None, 55 ) -> t.List[str]: 56 """ 57 Get the column names for a table. 58 59 Args: 60 table: the `Table` expression instance. 61 only_visible: whether to include invisible columns. 62 dialect: the SQL dialect that will be used to parse `table` if it's a string. 63 normalize: whether to normalize identifiers according to the dialect of interest. 64 65 Returns: 66 The list of column names. 67 """ 68 69 @abc.abstractmethod 70 def get_column_type( 71 self, 72 table: exp.Table | str, 73 column: exp.Column | str, 74 dialect: DialectType = None, 75 normalize: t.Optional[bool] = None, 76 ) -> exp.DataType: 77 """ 78 Get the `sqlglot.exp.DataType` type of a column in the schema. 79 80 Args: 81 table: the source table. 82 column: the target column. 
83 dialect: the SQL dialect that will be used to parse `table` if it's a string. 84 normalize: whether to normalize identifiers according to the dialect of interest. 85 86 Returns: 87 The resulting column type. 88 """ 89 90 def has_column( 91 self, 92 table: exp.Table | str, 93 column: exp.Column | str, 94 dialect: DialectType = None, 95 normalize: t.Optional[bool] = None, 96 ) -> bool: 97 """ 98 Returns whether or not `column` appears in `table`'s schema. 99 100 Args: 101 table: the source table. 102 column: the target column. 103 dialect: the SQL dialect that will be used to parse `table` if it's a string. 104 normalize: whether to normalize identifiers according to the dialect of interest. 105 106 Returns: 107 True if the column appears in the schema, False otherwise. 108 """ 109 name = column if isinstance(column, str) else column.name 110 return name in self.column_names(table, dialect=dialect, normalize=normalize) 111 112 @property 113 @abc.abstractmethod 114 def supported_table_args(self) -> t.Tuple[str, ...]: 115 """ 116 Table arguments this schema support, e.g. `("this", "db", "catalog")` 117 """ 118 119 @property 120 def empty(self) -> bool: 121 """Returns whether or not the schema is empty.""" 122 return True
Abstract base class for database schemas
27 @abc.abstractmethod 28 def add_table( 29 self, 30 table: exp.Table | str, 31 column_mapping: t.Optional[ColumnMapping] = None, 32 dialect: DialectType = None, 33 normalize: t.Optional[bool] = None, 34 match_depth: bool = True, 35 ) -> None: 36 """ 37 Register or update a table. Some implementing classes may require column information to also be provided. 38 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 39 40 Args: 41 table: the `Table` expression instance or string representing the table. 42 column_mapping: a column mapping that describes the structure of the table. 43 dialect: the SQL dialect that will be used to parse `table` if it's a string. 44 normalize: whether to normalize identifiers according to the dialect of interest. 45 match_depth: whether to enforce that the table must match the schema's depth or not. 46 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
48 @abc.abstractmethod 49 def column_names( 50 self, 51 table: exp.Table | str, 52 only_visible: bool = False, 53 dialect: DialectType = None, 54 normalize: t.Optional[bool] = None, 55 ) -> t.List[str]: 56 """ 57 Get the column names for a table. 58 59 Args: 60 table: the `Table` expression instance. 61 only_visible: whether to include invisible columns. 62 dialect: the SQL dialect that will be used to parse `table` if it's a string. 63 normalize: whether to normalize identifiers according to the dialect of interest. 64 65 Returns: 66 The list of column names. 67 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
69 @abc.abstractmethod 70 def get_column_type( 71 self, 72 table: exp.Table | str, 73 column: exp.Column | str, 74 dialect: DialectType = None, 75 normalize: t.Optional[bool] = None, 76 ) -> exp.DataType: 77 """ 78 Get the `sqlglot.exp.DataType` type of a column in the schema. 79 80 Args: 81 table: the source table. 82 column: the target column. 83 dialect: the SQL dialect that will be used to parse `table` if it's a string. 84 normalize: whether to normalize identifiers according to the dialect of interest. 85 86 Returns: 87 The resulting column type. 88 """
Get the `sqlglot.exp.DataType` type of a column in the schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
90 def has_column( 91 self, 92 table: exp.Table | str, 93 column: exp.Column | str, 94 dialect: DialectType = None, 95 normalize: t.Optional[bool] = None, 96 ) -> bool: 97 """ 98 Returns whether or not `column` appears in `table`'s schema. 99 100 Args: 101 table: the source table. 102 column: the target column. 103 dialect: the SQL dialect that will be used to parse `table` if it's a string. 104 normalize: whether to normalize identifiers according to the dialect of interest. 105 106 Returns: 107 True if the column appears in the schema, False otherwise. 108 """ 109 name = column if isinstance(column, str) else column.name 110 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether or not `column` appears in `table`'s schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
125class AbstractMappingSchema: 126 def __init__( 127 self, 128 mapping: t.Optional[t.Dict] = None, 129 ) -> None: 130 self.mapping = mapping or {} 131 self.mapping_trie = new_trie( 132 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 133 ) 134 self._supported_table_args: t.Tuple[str, ...] = tuple() 135 136 @property 137 def empty(self) -> bool: 138 return not self.mapping 139 140 def depth(self) -> int: 141 return dict_depth(self.mapping) 142 143 @property 144 def supported_table_args(self) -> t.Tuple[str, ...]: 145 if not self._supported_table_args and self.mapping: 146 depth = self.depth() 147 148 if not depth: # None 149 self._supported_table_args = tuple() 150 elif 1 <= depth <= 3: 151 self._supported_table_args = TABLE_ARGS[:depth] 152 else: 153 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 154 155 return self._supported_table_args 156 157 def table_parts(self, table: exp.Table) -> t.List[str]: 158 if isinstance(table.this, exp.ReadCSV): 159 return [table.this.name] 160 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 161 162 def find( 163 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 164 ) -> t.Optional[t.Any]: 165 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 166 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 167 168 if value == TrieResult.FAILED: 169 return None 170 171 if value == TrieResult.PREFIX: 172 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 173 174 if len(possibilities) == 1: 175 parts.extend(possibilities[0]) 176 else: 177 message = ", ".join(".".join(parts) for parts in possibilities) 178 if raise_on_missing: 179 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 180 return None 181 182 return self.nested_get(parts, raise_on_missing=raise_on_missing) 183 184 def nested_get( 185 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 186 ) 
-> t.Optional[t.Any]: 187 return nested_get( 188 d or self.mapping, 189 *zip(self.supported_table_args, reversed(parts)), 190 raise_on_missing=raise_on_missing, 191 )
162 def find( 163 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 164 ) -> t.Optional[t.Any]: 165 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 166 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 167 168 if value == TrieResult.FAILED: 169 return None 170 171 if value == TrieResult.PREFIX: 172 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 173 174 if len(possibilities) == 1: 175 parts.extend(possibilities[0]) 176 else: 177 message = ", ".join(".".join(parts) for parts in possibilities) 178 if raise_on_missing: 179 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 180 return None 181 182 return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Cache of type-string -> parsed exp.DataType, filled lazily by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Cached nesting depth; 0 means "not computed yet" (see depth()).
        self._depth = 0

        # Normalize all identifiers up front so lookups are consistent.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Alternate constructor: build a new MappingSchema from an existing one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied constructor args."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # If the table is already known and no new columns were supplied, keep the
        # existing entry instead of clobbering it with an empty mapping.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table; returns [] when the table is not found.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to exclude columns not present in `self.visible`.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` of a column; "unknown" if the table or column is missing.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a string: parse (and cache) it into a DataType.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        # depth - 1 because the innermost dict maps columns -> types, not more qualifiers.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        for keys in flattened_schema:
            # zip(keys, keys): nested_get wants (name, key) pairs; here the key doubles
            # as its own display name in error messages.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                raise SchemaError(
                    f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
                )

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` (if needed) and normalize its catalog/db/name identifiers."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Normalize a single identifier, falling back to this schema's dialect/normalize settings."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        )

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
213 def __init__( 214 self, 215 schema: t.Optional[t.Dict] = None, 216 visible: t.Optional[t.Dict] = None, 217 dialect: DialectType = None, 218 normalize: bool = True, 219 ) -> None: 220 self.dialect = dialect 221 self.visible = visible or {} 222 self.normalize = normalize 223 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 224 self._depth = 0 225 226 super().__init__(self._normalize(schema or {}))
248 def add_table( 249 self, 250 table: exp.Table | str, 251 column_mapping: t.Optional[ColumnMapping] = None, 252 dialect: DialectType = None, 253 normalize: t.Optional[bool] = None, 254 match_depth: bool = True, 255 ) -> None: 256 """ 257 Register or update a table. Updates are only performed if a new column mapping is provided. 258 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 259 260 Args: 261 table: the `Table` expression instance or string representing the table. 262 column_mapping: a column mapping that describes the structure of the table. 263 dialect: the SQL dialect that will be used to parse `table` if it's a string. 264 normalize: whether to normalize identifiers according to the dialect of interest. 265 match_depth: whether to enforce that the table must match the schema's depth or not. 266 """ 267 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 268 269 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 270 raise SchemaError( 271 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 272 f"schema's nesting level: {self.depth()}." 273 ) 274 275 normalized_column_mapping = { 276 self._normalize_name(key, dialect=dialect, normalize=normalize): value 277 for key, value in ensure_column_mapping(column_mapping).items() 278 } 279 280 schema = self.find(normalized_table, raise_on_missing=False) 281 if schema and not normalized_column_mapping: 282 return 283 284 parts = self.table_parts(normalized_table) 285 286 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 287 new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
289 def column_names( 290 self, 291 table: exp.Table | str, 292 only_visible: bool = False, 293 dialect: DialectType = None, 294 normalize: t.Optional[bool] = None, 295 ) -> t.List[str]: 296 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 297 298 schema = self.find(normalized_table) 299 if schema is None: 300 return [] 301 302 if not only_visible or not self.visible: 303 return list(schema) 304 305 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 306 return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
308 def get_column_type( 309 self, 310 table: exp.Table | str, 311 column: exp.Column | str, 312 dialect: DialectType = None, 313 normalize: t.Optional[bool] = None, 314 ) -> exp.DataType: 315 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 316 317 normalized_column_name = self._normalize_name( 318 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 319 ) 320 321 table_schema = self.find(normalized_table, raise_on_missing=False) 322 if table_schema: 323 column_type = table_schema.get(normalized_column_name) 324 325 if isinstance(column_type, exp.DataType): 326 return column_type 327 elif isinstance(column_type, str): 328 return self._to_data_type(column_type, dialect=dialect) 329 330 return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
332 def has_column( 333 self, 334 table: exp.Table | str, 335 column: exp.Column | str, 336 dialect: DialectType = None, 337 normalize: t.Optional[bool] = None, 338 ) -> bool: 339 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 340 341 normalized_column_name = self._normalize_name( 342 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 343 ) 344 345 table_schema = self.find(normalized_table, raise_on_missing=False) 346 return normalized_column_name in table_schema if table_schema else False
Returns whether or not `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
450def normalize_name( 451 identifier: str | exp.Identifier, 452 dialect: DialectType = None, 453 is_table: bool = False, 454 normalize: t.Optional[bool] = True, 455) -> str: 456 if isinstance(identifier, str): 457 identifier = exp.parse_identifier(identifier, dialect=dialect) 458 459 if not normalize: 460 return identifier.name 461 462 # This can be useful for normalize_identifier 463 identifier.meta["is_table"] = is_table 464 return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
474def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 475 if mapping is None: 476 return {} 477 elif isinstance(mapping, dict): 478 return mapping 479 elif isinstance(mapping, str): 480 col_name_type_strs = [x.strip() for x in mapping.split(",")] 481 return { 482 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 483 for name_type_str in col_name_type_strs 484 } 485 # Check if mapping looks like a DataFrame StructType 486 elif hasattr(mapping, "simpleString"): 487 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 488 elif isinstance(mapping, list): 489 return {x.strip(): None for x in mapping} 490 491 raise ValueError(f"Invalid mapping provided: {type(mapping)}")
494def flatten_schema( 495 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 496) -> t.List[t.List[str]]: 497 tables = [] 498 keys = keys or [] 499 500 for k, v in schema.items(): 501 if depth >= 2: 502 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 503 elif depth == 1: 504 tables.append(keys + [k]) 505 506 return tables
509def nested_get( 510 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 511) -> t.Optional[t.Any]: 512 """ 513 Get a value for a nested dictionary. 514 515 Args: 516 d: the dictionary to search. 517 *path: tuples of (name, key), where: 518 `key` is the key in the dictionary to get. 519 `name` is a string to use in the error if `key` isn't found. 520 521 Returns: 522 The value or None if it doesn't exist. 523 """ 524 for name, key in path: 525 d = d.get(key) # type: ignore 526 if d is None: 527 if raise_on_missing: 528 name = "table" if name == "this" else name 529 raise ValueError(f"Unknown {name}: {key}") 530 return None 531 532 return d
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
535def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 536 """ 537 In-place set a value for a nested dictionary 538 539 Example: 540 >>> nested_set({}, ["top_key", "second_key"], "value") 541 {'top_key': {'second_key': 'value'}} 542 543 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 544 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 545 546 Args: 547 d: dictionary to update. 548 keys: the keys that makeup the path to `value`. 549 value: the value to set in the dictionary for the given key path. 550 551 Returns: 552 The (possibly) updated dictionary. 553 """ 554 if not keys: 555 return d 556 557 if len(keys) == 1: 558 d[keys[0]] = value 559 return d 560 561 subd = d 562 for key in keys[:-1]: 563 if key not in subd: 564 subd = subd.setdefault(key, {}) 565 else: 566 subd = subd[key] 567 568 subd[keys[-1]] = value 569 return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.