sqlglot.schema
1from __future__ import annotations 2 3import abc 4import typing as t 5 6import sqlglot 7from sqlglot import expressions as exp 8from sqlglot._typing import T 9from sqlglot.dialects.dialect import Dialect 10from sqlglot.errors import ParseError, SchemaError 11from sqlglot.helper import dict_depth 12from sqlglot.trie import TrieResult, in_trie, new_trie 13 14if t.TYPE_CHECKING: 15 from sqlglot.dataframe.sql.types import StructType 16 from sqlglot.dialects.dialect import DialectType 17 18 ColumnMapping = t.Union[t.Dict, str, StructType, t.List] 19 20TABLE_ARGS = ("this", "db", "catalog") 21 22 23class Schema(abc.ABC): 24 """Abstract base class for database schemas""" 25 26 dialect: DialectType 27 28 @abc.abstractmethod 29 def add_table( 30 self, 31 table: exp.Table | str, 32 column_mapping: t.Optional[ColumnMapping] = None, 33 dialect: DialectType = None, 34 normalize: t.Optional[bool] = None, 35 ) -> None: 36 """ 37 Register or update a table. Some implementing classes may require column information to also be provided. 38 39 Args: 40 table: the `Table` expression instance or string representing the table. 41 column_mapping: a column mapping that describes the structure of the table. 42 dialect: the SQL dialect that will be used to parse `table` if it's a string. 43 normalize: whether to normalize identifiers according to the dialect of interest. 44 """ 45 46 @abc.abstractmethod 47 def column_names( 48 self, 49 table: exp.Table | str, 50 only_visible: bool = False, 51 dialect: DialectType = None, 52 normalize: t.Optional[bool] = None, 53 ) -> t.List[str]: 54 """ 55 Get the column names for a table. 56 57 Args: 58 table: the `Table` expression instance. 59 only_visible: whether to include invisible columns. 60 dialect: the SQL dialect that will be used to parse `table` if it's a string. 61 normalize: whether to normalize identifiers according to the dialect of interest. 62 63 Returns: 64 The list of column names. 
65 """ 66 67 @abc.abstractmethod 68 def get_column_type( 69 self, 70 table: exp.Table | str, 71 column: exp.Column, 72 dialect: DialectType = None, 73 normalize: t.Optional[bool] = None, 74 ) -> exp.DataType: 75 """ 76 Get the `sqlglot.exp.DataType` type of a column in the schema. 77 78 Args: 79 table: the source table. 80 column: the target column. 81 dialect: the SQL dialect that will be used to parse `table` if it's a string. 82 normalize: whether to normalize identifiers according to the dialect of interest. 83 84 Returns: 85 The resulting column type. 86 """ 87 88 @property 89 @abc.abstractmethod 90 def supported_table_args(self) -> t.Tuple[str, ...]: 91 """ 92 Table arguments this schema support, e.g. `("this", "db", "catalog")` 93 """ 94 95 @property 96 def empty(self) -> bool: 97 """Returns whether or not the schema is empty.""" 98 return True 99 100 101class AbstractMappingSchema(t.Generic[T]): 102 def __init__( 103 self, 104 mapping: t.Optional[t.Dict] = None, 105 ) -> None: 106 self.mapping = mapping or {} 107 self.mapping_trie = new_trie( 108 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth()) 109 ) 110 self._supported_table_args: t.Tuple[str, ...] = tuple() 111 112 @property 113 def empty(self) -> bool: 114 return not self.mapping 115 116 def _depth(self) -> int: 117 return dict_depth(self.mapping) 118 119 @property 120 def supported_table_args(self) -> t.Tuple[str, ...]: 121 if not self._supported_table_args and self.mapping: 122 depth = self._depth() 123 124 if not depth: # None 125 self._supported_table_args = tuple() 126 elif 1 <= depth <= 3: 127 self._supported_table_args = TABLE_ARGS[:depth] 128 else: 129 raise SchemaError(f"Invalid mapping shape. 
Depth: {depth}") 130 131 return self._supported_table_args 132 133 def table_parts(self, table: exp.Table) -> t.List[str]: 134 if isinstance(table.this, exp.ReadCSV): 135 return [table.this.name] 136 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 137 138 def find( 139 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 140 ) -> t.Optional[T]: 141 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 142 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 143 144 if value == TrieResult.FAILED: 145 return None 146 147 if value == TrieResult.PREFIX: 148 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 149 150 if len(possibilities) == 1: 151 parts.extend(possibilities[0]) 152 else: 153 message = ", ".join(".".join(parts) for parts in possibilities) 154 if raise_on_missing: 155 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 156 return None 157 158 return self.nested_get(parts, raise_on_missing=raise_on_missing) 159 160 def nested_get( 161 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 162 ) -> t.Optional[t.Any]: 163 return nested_get( 164 d or self.mapping, 165 *zip(self.supported_table_args, reversed(parts)), 166 raise_on_missing=raise_on_missing, 167 ) 168 169 170class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): 171 """ 172 Schema based on a nested mapping. 173 174 Args: 175 schema: Mapping in one of the following forms: 176 1. {table: {col: type}} 177 2. {db: {table: {col: type}}} 178 3. {catalog: {db: {table: {col: type}}}} 179 4. None - Tables will be added later 180 visible: Optional mapping of which columns in the schema are visible. If not provided, all columns 181 are assumed to be visible. The nesting should mirror that of the schema: 182 1. {table: set(*cols)}} 183 2. {db: {table: set(*cols)}}} 184 3. 
{catalog: {db: {table: set(*cols)}}}} 185 dialect: The dialect to be used for custom type mappings & parsing string arguments. 186 normalize: Whether to normalize identifier names according to the given dialect or not. 187 """ 188 189 def __init__( 190 self, 191 schema: t.Optional[t.Dict] = None, 192 visible: t.Optional[t.Dict] = None, 193 dialect: DialectType = None, 194 normalize: bool = True, 195 ) -> None: 196 self.dialect = dialect 197 self.visible = visible or {} 198 self.normalize = normalize 199 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 200 201 super().__init__(self._normalize(schema or {})) 202 203 @classmethod 204 def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: 205 return MappingSchema( 206 schema=mapping_schema.mapping, 207 visible=mapping_schema.visible, 208 dialect=mapping_schema.dialect, 209 normalize=mapping_schema.normalize, 210 ) 211 212 def copy(self, **kwargs) -> MappingSchema: 213 return MappingSchema( 214 **{ # type: ignore 215 "schema": self.mapping.copy(), 216 "visible": self.visible.copy(), 217 "dialect": self.dialect, 218 "normalize": self.normalize, 219 **kwargs, 220 } 221 ) 222 223 def add_table( 224 self, 225 table: exp.Table | str, 226 column_mapping: t.Optional[ColumnMapping] = None, 227 dialect: DialectType = None, 228 normalize: t.Optional[bool] = None, 229 ) -> None: 230 """ 231 Register or update a table. Updates are only performed if a new column mapping is provided. 232 233 Args: 234 table: the `Table` expression instance or string representing the table. 235 column_mapping: a column mapping that describes the structure of the table. 236 dialect: the SQL dialect that will be used to parse `table` if it's a string. 237 normalize: whether to normalize identifiers according to the dialect of interest. 
238 """ 239 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 240 241 normalized_column_mapping = { 242 self._normalize_name(key, dialect=dialect, normalize=normalize): value 243 for key, value in ensure_column_mapping(column_mapping).items() 244 } 245 246 schema = self.find(normalized_table, raise_on_missing=False) 247 if schema and not normalized_column_mapping: 248 return 249 250 parts = self.table_parts(normalized_table) 251 252 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 253 new_trie([parts], self.mapping_trie) 254 255 def column_names( 256 self, 257 table: exp.Table | str, 258 only_visible: bool = False, 259 dialect: DialectType = None, 260 normalize: t.Optional[bool] = None, 261 ) -> t.List[str]: 262 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 263 264 schema = self.find(normalized_table) 265 if schema is None: 266 return [] 267 268 if not only_visible or not self.visible: 269 return list(schema) 270 271 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 272 return [col for col in schema if col in visible] 273 274 def get_column_type( 275 self, 276 table: exp.Table | str, 277 column: exp.Column, 278 dialect: DialectType = None, 279 normalize: t.Optional[bool] = None, 280 ) -> exp.DataType: 281 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 282 283 normalized_column_name = self._normalize_name( 284 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 285 ) 286 287 table_schema = self.find(normalized_table, raise_on_missing=False) 288 if table_schema: 289 column_type = table_schema.get(normalized_column_name) 290 291 if isinstance(column_type, exp.DataType): 292 return column_type 293 elif isinstance(column_type, str): 294 return self._to_data_type(column_type.upper(), dialect=dialect) 295 296 return exp.DataType.build("unknown") 297 298 def 
_normalize(self, schema: t.Dict) -> t.Dict: 299 """ 300 Normalizes all identifiers in the schema. 301 302 Args: 303 schema: the schema to normalize. 304 305 Returns: 306 The normalized schema mapping. 307 """ 308 flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) 309 310 normalized_mapping: t.Dict = {} 311 for keys in flattened_schema: 312 columns = nested_get(schema, *zip(keys, keys)) 313 assert columns is not None 314 315 normalized_keys = [ 316 self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys 317 ] 318 for column_name, column_type in columns.items(): 319 nested_set( 320 normalized_mapping, 321 normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)], 322 column_type, 323 ) 324 325 return normalized_mapping 326 327 def _normalize_table( 328 self, 329 table: exp.Table | str, 330 dialect: DialectType = None, 331 normalize: t.Optional[bool] = None, 332 ) -> exp.Table: 333 normalized_table = exp.maybe_parse( 334 table, into=exp.Table, dialect=dialect or self.dialect, copy=True 335 ) 336 337 for arg in TABLE_ARGS: 338 value = normalized_table.args.get(arg) 339 if isinstance(value, (str, exp.Identifier)): 340 normalized_table.set( 341 arg, 342 exp.to_identifier( 343 self._normalize_name( 344 value, dialect=dialect, is_table=True, normalize=normalize 345 ) 346 ), 347 ) 348 349 return normalized_table 350 351 def _normalize_name( 352 self, 353 name: str | exp.Identifier, 354 dialect: DialectType = None, 355 is_table: bool = False, 356 normalize: t.Optional[bool] = None, 357 ) -> str: 358 dialect = dialect or self.dialect 359 normalize = self.normalize if normalize is None else normalize 360 361 try: 362 identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) 363 except ParseError: 364 return name if isinstance(name, str) else name.name 365 366 name = identifier.name 367 if not normalize: 368 return name 369 370 # This can be useful for normalize_identifier 371 
identifier.meta["is_table"] = is_table 372 return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name 373 374 def _depth(self) -> int: 375 # The columns themselves are a mapping, but we don't want to include those 376 return super()._depth() - 1 377 378 def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: 379 """ 380 Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. 381 382 Args: 383 schema_type: the type we want to convert. 384 dialect: the SQL dialect that will be used to parse `schema_type`, if needed. 385 386 Returns: 387 The resulting expression type. 388 """ 389 if schema_type not in self._type_mapping_cache: 390 dialect = dialect or self.dialect 391 392 try: 393 expression = exp.DataType.build(schema_type, dialect=dialect) 394 self._type_mapping_cache[schema_type] = expression 395 except AttributeError: 396 in_dialect = f" in dialect {dialect}" if dialect else "" 397 raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") 398 399 return self._type_mapping_cache[schema_type] 400 401 402def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: 403 if isinstance(schema, Schema): 404 return schema 405 406 return MappingSchema(schema, **kwargs) 407 408 409def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 410 if mapping is None: 411 return {} 412 elif isinstance(mapping, dict): 413 return mapping 414 elif isinstance(mapping, str): 415 col_name_type_strs = [x.strip() for x in mapping.split(",")] 416 return { 417 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 418 for name_type_str in col_name_type_strs 419 } 420 # Check if mapping looks like a DataFrame StructType 421 elif hasattr(mapping, "simpleString"): 422 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 423 elif isinstance(mapping, list): 424 return {x.strip(): None for x in mapping} 425 
426 raise ValueError(f"Invalid mapping provided: {type(mapping)}") 427 428 429def flatten_schema( 430 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 431) -> t.List[t.List[str]]: 432 tables = [] 433 keys = keys or [] 434 435 for k, v in schema.items(): 436 if depth >= 2: 437 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 438 elif depth == 1: 439 tables.append(keys + [k]) 440 441 return tables 442 443 444def nested_get( 445 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 446) -> t.Optional[t.Any]: 447 """ 448 Get a value for a nested dictionary. 449 450 Args: 451 d: the dictionary to search. 452 *path: tuples of (name, key), where: 453 `key` is the key in the dictionary to get. 454 `name` is a string to use in the error if `key` isn't found. 455 456 Returns: 457 The value or None if it doesn't exist. 458 """ 459 for name, key in path: 460 d = d.get(key) # type: ignore 461 if d is None: 462 if raise_on_missing: 463 name = "table" if name == "this" else name 464 raise ValueError(f"Unknown {name}: {key}") 465 return None 466 467 return d 468 469 470def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 471 """ 472 In-place set a value for a nested dictionary 473 474 Example: 475 >>> nested_set({}, ["top_key", "second_key"], "value") 476 {'top_key': {'second_key': 'value'}} 477 478 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 479 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 480 481 Args: 482 d: dictionary to update. 483 keys: the keys that makeup the path to `value`. 484 value: the value to set in the dictionary for the given key path. 485 486 Returns: 487 The (possibly) updated dictionary. 
488 """ 489 if not keys: 490 return d 491 492 if len(keys) == 1: 493 d[keys[0]] = value 494 return d 495 496 subd = d 497 for key in keys[:-1]: 498 if key not in subd: 499 subd = subd.setdefault(key, {}) 500 else: 501 subd = subd[key] 502 503 subd[keys[-1]] = value 504 return d
24class Schema(abc.ABC): 25 """Abstract base class for database schemas""" 26 27 dialect: DialectType 28 29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 ) -> None: 37 """ 38 Register or update a table. Some implementing classes may require column information to also be provided. 39 40 Args: 41 table: the `Table` expression instance or string representing the table. 42 column_mapping: a column mapping that describes the structure of the table. 43 dialect: the SQL dialect that will be used to parse `table` if it's a string. 44 normalize: whether to normalize identifiers according to the dialect of interest. 45 """ 46 47 @abc.abstractmethod 48 def column_names( 49 self, 50 table: exp.Table | str, 51 only_visible: bool = False, 52 dialect: DialectType = None, 53 normalize: t.Optional[bool] = None, 54 ) -> t.List[str]: 55 """ 56 Get the column names for a table. 57 58 Args: 59 table: the `Table` expression instance. 60 only_visible: whether to include invisible columns. 61 dialect: the SQL dialect that will be used to parse `table` if it's a string. 62 normalize: whether to normalize identifiers according to the dialect of interest. 63 64 Returns: 65 The list of column names. 66 """ 67 68 @abc.abstractmethod 69 def get_column_type( 70 self, 71 table: exp.Table | str, 72 column: exp.Column, 73 dialect: DialectType = None, 74 normalize: t.Optional[bool] = None, 75 ) -> exp.DataType: 76 """ 77 Get the `sqlglot.exp.DataType` type of a column in the schema. 78 79 Args: 80 table: the source table. 81 column: the target column. 82 dialect: the SQL dialect that will be used to parse `table` if it's a string. 83 normalize: whether to normalize identifiers according to the dialect of interest. 84 85 Returns: 86 The resulting column type. 
87 """ 88 89 @property 90 @abc.abstractmethod 91 def supported_table_args(self) -> t.Tuple[str, ...]: 92 """ 93 Table arguments this schema support, e.g. `("this", "db", "catalog")` 94 """ 95 96 @property 97 def empty(self) -> bool: 98 """Returns whether or not the schema is empty.""" 99 return True
Abstract base class for database schemas
29 @abc.abstractmethod 30 def add_table( 31 self, 32 table: exp.Table | str, 33 column_mapping: t.Optional[ColumnMapping] = None, 34 dialect: DialectType = None, 35 normalize: t.Optional[bool] = None, 36 ) -> None: 37 """ 38 Register or update a table. Some implementing classes may require column information to also be provided. 39 40 Args: 41 table: the `Table` expression instance or string representing the table. 42 column_mapping: a column mapping that describes the structure of the table. 43 dialect: the SQL dialect that will be used to parse `table` if it's a string. 44 normalize: whether to normalize identifiers according to the dialect of interest. 45 """
Register or update a table. Some implementing classes may require column information to also be provided.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
47 @abc.abstractmethod 48 def column_names( 49 self, 50 table: exp.Table | str, 51 only_visible: bool = False, 52 dialect: DialectType = None, 53 normalize: t.Optional[bool] = None, 54 ) -> t.List[str]: 55 """ 56 Get the column names for a table. 57 58 Args: 59 table: the `Table` expression instance. 60 only_visible: whether to include invisible columns. 61 dialect: the SQL dialect that will be used to parse `table` if it's a string. 62 normalize: whether to normalize identifiers according to the dialect of interest. 63 64 Returns: 65 The list of column names. 66 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
68 @abc.abstractmethod 69 def get_column_type( 70 self, 71 table: exp.Table | str, 72 column: exp.Column, 73 dialect: DialectType = None, 74 normalize: t.Optional[bool] = None, 75 ) -> exp.DataType: 76 """ 77 Get the `sqlglot.exp.DataType` type of a column in the schema. 78 79 Args: 80 table: the source table. 81 column: the target column. 82 dialect: the SQL dialect that will be used to parse `table` if it's a string. 83 normalize: whether to normalize identifiers according to the dialect of interest. 84 85 Returns: 86 The resulting column type. 87 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
102class AbstractMappingSchema(t.Generic[T]): 103 def __init__( 104 self, 105 mapping: t.Optional[t.Dict] = None, 106 ) -> None: 107 self.mapping = mapping or {} 108 self.mapping_trie = new_trie( 109 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth()) 110 ) 111 self._supported_table_args: t.Tuple[str, ...] = tuple() 112 113 @property 114 def empty(self) -> bool: 115 return not self.mapping 116 117 def _depth(self) -> int: 118 return dict_depth(self.mapping) 119 120 @property 121 def supported_table_args(self) -> t.Tuple[str, ...]: 122 if not self._supported_table_args and self.mapping: 123 depth = self._depth() 124 125 if not depth: # None 126 self._supported_table_args = tuple() 127 elif 1 <= depth <= 3: 128 self._supported_table_args = TABLE_ARGS[:depth] 129 else: 130 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 131 132 return self._supported_table_args 133 134 def table_parts(self, table: exp.Table) -> t.List[str]: 135 if isinstance(table.this, exp.ReadCSV): 136 return [table.this.name] 137 return [table.text(part) for part in TABLE_ARGS if table.text(part)] 138 139 def find( 140 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 141 ) -> t.Optional[T]: 142 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 143 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 144 145 if value == TrieResult.FAILED: 146 return None 147 148 if value == TrieResult.PREFIX: 149 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 150 151 if len(possibilities) == 1: 152 parts.extend(possibilities[0]) 153 else: 154 message = ", ".join(".".join(parts) for parts in possibilities) 155 if raise_on_missing: 156 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 157 return None 158 159 return self.nested_get(parts, raise_on_missing=raise_on_missing) 160 161 def nested_get( 162 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, 
raise_on_missing=True 163 ) -> t.Optional[t.Any]: 164 return nested_get( 165 d or self.mapping, 166 *zip(self.supported_table_args, reversed(parts)), 167 raise_on_missing=raise_on_missing, 168 )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
139 def find( 140 self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True 141 ) -> t.Optional[T]: 142 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 143 value, trie = in_trie(self.mapping_trie if trie is None else trie, parts) 144 145 if value == TrieResult.FAILED: 146 return None 147 148 if value == TrieResult.PREFIX: 149 possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) 150 151 if len(possibilities) == 1: 152 parts.extend(possibilities[0]) 153 else: 154 message = ", ".join(".".join(parts) for parts in possibilities) 155 if raise_on_missing: 156 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 157 return None 158 159 return self.nested_get(parts, raise_on_missing=raise_on_missing)
171class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): 172 """ 173 Schema based on a nested mapping. 174 175 Args: 176 schema: Mapping in one of the following forms: 177 1. {table: {col: type}} 178 2. {db: {table: {col: type}}} 179 3. {catalog: {db: {table: {col: type}}}} 180 4. None - Tables will be added later 181 visible: Optional mapping of which columns in the schema are visible. If not provided, all columns 182 are assumed to be visible. The nesting should mirror that of the schema: 183 1. {table: set(*cols)}} 184 2. {db: {table: set(*cols)}}} 185 3. {catalog: {db: {table: set(*cols)}}}} 186 dialect: The dialect to be used for custom type mappings & parsing string arguments. 187 normalize: Whether to normalize identifier names according to the given dialect or not. 188 """ 189 190 def __init__( 191 self, 192 schema: t.Optional[t.Dict] = None, 193 visible: t.Optional[t.Dict] = None, 194 dialect: DialectType = None, 195 normalize: bool = True, 196 ) -> None: 197 self.dialect = dialect 198 self.visible = visible or {} 199 self.normalize = normalize 200 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 201 202 super().__init__(self._normalize(schema or {})) 203 204 @classmethod 205 def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: 206 return MappingSchema( 207 schema=mapping_schema.mapping, 208 visible=mapping_schema.visible, 209 dialect=mapping_schema.dialect, 210 normalize=mapping_schema.normalize, 211 ) 212 213 def copy(self, **kwargs) -> MappingSchema: 214 return MappingSchema( 215 **{ # type: ignore 216 "schema": self.mapping.copy(), 217 "visible": self.visible.copy(), 218 "dialect": self.dialect, 219 "normalize": self.normalize, 220 **kwargs, 221 } 222 ) 223 224 def add_table( 225 self, 226 table: exp.Table | str, 227 column_mapping: t.Optional[ColumnMapping] = None, 228 dialect: DialectType = None, 229 normalize: t.Optional[bool] = None, 230 ) -> None: 231 """ 232 Register or update a table. 
Updates are only performed if a new column mapping is provided. 233 234 Args: 235 table: the `Table` expression instance or string representing the table. 236 column_mapping: a column mapping that describes the structure of the table. 237 dialect: the SQL dialect that will be used to parse `table` if it's a string. 238 normalize: whether to normalize identifiers according to the dialect of interest. 239 """ 240 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 241 242 normalized_column_mapping = { 243 self._normalize_name(key, dialect=dialect, normalize=normalize): value 244 for key, value in ensure_column_mapping(column_mapping).items() 245 } 246 247 schema = self.find(normalized_table, raise_on_missing=False) 248 if schema and not normalized_column_mapping: 249 return 250 251 parts = self.table_parts(normalized_table) 252 253 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 254 new_trie([parts], self.mapping_trie) 255 256 def column_names( 257 self, 258 table: exp.Table | str, 259 only_visible: bool = False, 260 dialect: DialectType = None, 261 normalize: t.Optional[bool] = None, 262 ) -> t.List[str]: 263 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 264 265 schema = self.find(normalized_table) 266 if schema is None: 267 return [] 268 269 if not only_visible or not self.visible: 270 return list(schema) 271 272 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 273 return [col for col in schema if col in visible] 274 275 def get_column_type( 276 self, 277 table: exp.Table | str, 278 column: exp.Column, 279 dialect: DialectType = None, 280 normalize: t.Optional[bool] = None, 281 ) -> exp.DataType: 282 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 283 284 normalized_column_name = self._normalize_name( 285 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 
286 ) 287 288 table_schema = self.find(normalized_table, raise_on_missing=False) 289 if table_schema: 290 column_type = table_schema.get(normalized_column_name) 291 292 if isinstance(column_type, exp.DataType): 293 return column_type 294 elif isinstance(column_type, str): 295 return self._to_data_type(column_type.upper(), dialect=dialect) 296 297 return exp.DataType.build("unknown") 298 299 def _normalize(self, schema: t.Dict) -> t.Dict: 300 """ 301 Normalizes all identifiers in the schema. 302 303 Args: 304 schema: the schema to normalize. 305 306 Returns: 307 The normalized schema mapping. 308 """ 309 flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) 310 311 normalized_mapping: t.Dict = {} 312 for keys in flattened_schema: 313 columns = nested_get(schema, *zip(keys, keys)) 314 assert columns is not None 315 316 normalized_keys = [ 317 self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys 318 ] 319 for column_name, column_type in columns.items(): 320 nested_set( 321 normalized_mapping, 322 normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)], 323 column_type, 324 ) 325 326 return normalized_mapping 327 328 def _normalize_table( 329 self, 330 table: exp.Table | str, 331 dialect: DialectType = None, 332 normalize: t.Optional[bool] = None, 333 ) -> exp.Table: 334 normalized_table = exp.maybe_parse( 335 table, into=exp.Table, dialect=dialect or self.dialect, copy=True 336 ) 337 338 for arg in TABLE_ARGS: 339 value = normalized_table.args.get(arg) 340 if isinstance(value, (str, exp.Identifier)): 341 normalized_table.set( 342 arg, 343 exp.to_identifier( 344 self._normalize_name( 345 value, dialect=dialect, is_table=True, normalize=normalize 346 ) 347 ), 348 ) 349 350 return normalized_table 351 352 def _normalize_name( 353 self, 354 name: str | exp.Identifier, 355 dialect: DialectType = None, 356 is_table: bool = False, 357 normalize: t.Optional[bool] = None, 358 ) -> str: 359 dialect = dialect or 
self.dialect 360 normalize = self.normalize if normalize is None else normalize 361 362 try: 363 identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) 364 except ParseError: 365 return name if isinstance(name, str) else name.name 366 367 name = identifier.name 368 if not normalize: 369 return name 370 371 # This can be useful for normalize_identifier 372 identifier.meta["is_table"] = is_table 373 return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name 374 375 def _depth(self) -> int: 376 # The columns themselves are a mapping, but we don't want to include those 377 return super()._depth() - 1 378 379 def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: 380 """ 381 Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. 382 383 Args: 384 schema_type: the type we want to convert. 385 dialect: the SQL dialect that will be used to parse `schema_type`, if needed. 386 387 Returns: 388 The resulting expression type. 389 """ 390 if schema_type not in self._type_mapping_cache: 391 dialect = dialect or self.dialect 392 393 try: 394 expression = exp.DataType.build(schema_type, dialect=dialect) 395 self._type_mapping_cache[schema_type] = expression 396 except AttributeError: 397 in_dialect = f" in dialect {dialect}" if dialect else "" 398 raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") 399 400 return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
190 def __init__( 191 self, 192 schema: t.Optional[t.Dict] = None, 193 visible: t.Optional[t.Dict] = None, 194 dialect: DialectType = None, 195 normalize: bool = True, 196 ) -> None: 197 self.dialect = dialect 198 self.visible = visible or {} 199 self.normalize = normalize 200 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 201 202 super().__init__(self._normalize(schema or {}))
410def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 411 if mapping is None: 412 return {} 413 elif isinstance(mapping, dict): 414 return mapping 415 elif isinstance(mapping, str): 416 col_name_type_strs = [x.strip() for x in mapping.split(",")] 417 return { 418 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 419 for name_type_str in col_name_type_strs 420 } 421 # Check if mapping looks like a DataFrame StructType 422 elif hasattr(mapping, "simpleString"): 423 return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} 424 elif isinstance(mapping, list): 425 return {x.strip(): None for x in mapping} 426 427 raise ValueError(f"Invalid mapping provided: {type(mapping)}")
430def flatten_schema( 431 schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None 432) -> t.List[t.List[str]]: 433 tables = [] 434 keys = keys or [] 435 436 for k, v in schema.items(): 437 if depth >= 2: 438 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 439 elif depth == 1: 440 tables.append(keys + [k]) 441 442 return tables
445def nested_get( 446 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 447) -> t.Optional[t.Any]: 448 """ 449 Get a value for a nested dictionary. 450 451 Args: 452 d: the dictionary to search. 453 *path: tuples of (name, key), where: 454 `key` is the key in the dictionary to get. 455 `name` is a string to use in the error if `key` isn't found. 456 457 Returns: 458 The value or None if it doesn't exist. 459 """ 460 for name, key in path: 461 d = d.get(key) # type: ignore 462 if d is None: 463 if raise_on_missing: 464 name = "table" if name == "this" else name 465 raise ValueError(f"Unknown {name}: {key}") 466 return None 467 468 return d
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
  `key` is the key in the dictionary to get.
  `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
471def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 472 """ 473 In-place set a value for a nested dictionary 474 475 Example: 476 >>> nested_set({}, ["top_key", "second_key"], "value") 477 {'top_key': {'second_key': 'value'}} 478 479 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 480 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 481 482 Args: 483 d: dictionary to update. 484 keys: the keys that makeup the path to `value`. 485 value: the value to set in the dictionary for the given key path. 486 487 Returns: 488 The (possibly) updated dictionary. 489 """ 490 if not keys: 491 return d 492 493 if len(keys) == 1: 494 d[keys[0]] = value 495 return d 496 497 subd = d 498 for key in keys[:-1]: 499 if key not in subd: 500 subd = subd.setdefault(key, {}) 501 else: 502 subd = subd[key] 503 504 subd[keys[-1]] = value 505 return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.