sqlglot.schema
"""Schema abstractions used by sqlglot to describe tables, columns and column types."""

from __future__ import annotations

import abc
import typing as t

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth, first
from sqlglot.trie import TrieResult, in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    # A column mapping may be a {name: type} dict, a "name: type, ..." string,
    # or a plain list of column names (types unknown).
    ColumnMapping = t.Union[t.Dict, str, t.List]


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    dialect: DialectType

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.Sequence[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to restrict the result to only the visible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The sequence of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type.
        """

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """
        Returns whether `column` appears in `table`'s schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            True if the column appears in the schema, False otherwise.
        """
        name = column if isinstance(column, str) else column.name
        return name in self.column_names(table, dialect=dialect, normalize=normalize)

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether the schema is empty."""
        return True


class AbstractMappingSchema:
    """Nested-dict backed schema lookup with a trie index over reversed table paths."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # The trie keys are table paths reversed (table, db, catalog), so that
        # lookups with fewer qualifiers still match as prefixes.
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
        )
        # Lazily computed in `supported_table_args`.
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether the mapping contains no tables."""
        return not self.mapping

    def depth(self) -> int:
        """Nesting depth of the underlying mapping."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """Table path parts supported by this mapping's depth, cached after first use."""
        if not self._supported_table_args and self.mapping:
            depth = self.depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = exp.TABLE_PARTS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Returns the non-empty path parts of `table` (this/db/catalog order)."""
        if isinstance(table.this, exp.ReadCSV):
            # A READ_CSV "table" is identified solely by its file name.
            return [table.this.name]
        return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]

    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
        """
        Returns the schema of a given table.

        Args:
            table: the target table.
            raise_on_missing: whether to raise in case the schema is not found.

        Returns:
            The schema of the target table.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # The table was under-qualified; only proceed if the prefix resolves
            # to exactly one candidate, otherwise the lookup is ambiguous.
            possibilities = flatten_schema(trie)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[t.Any]:
        """Fetches the value at the (reversed) `parts` path from `d` or `self.mapping`."""
        return nested_get(
            # NOTE: an explicitly passed empty dict also falls back to self.mapping.
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Caches parsed string types -> exp.DataType, see `_to_data_type`.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed in `depth`.
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Builds a new `MappingSchema` sharing the given schema's state."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Returns a shallow copy of this schema, with `kwargs` overriding constructor args."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # A known table with no new columns provided is a no-op update.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Returns the column names of `table`, optionally restricted to visible columns."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Returns the type of `column` in `table`, or the UNKNOWN type if not found."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # Stored as a raw string: parse (and cache) it as a DataType.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        """Returns whether the normalized `column` exists in the normalized `table`'s schema."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))

            # Each leaf must be a flat {column: type} dict; shallower or deeper
            # entries mean the table path doesn't match the schema's depth.
            if not isinstance(columns, dict):
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if isinstance(first(columns.values()), dict):
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parses `table` if needed and normalizes each of its identifier parts."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we're going to mutate the parsed expression below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Returns the normalized text of a single identifier."""
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        """Schema nesting depth, excluding the column level; cached after first use."""
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def normalize_name(
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> exp.Identifier:
    """Parses `identifier` if needed and normalizes it per the dialect's rules."""
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    if not normalize:
        return identifier

    # this is used for normalize_identifier, bigquery has special rules pertaining tables
    identifier.meta["is_table"] = is_table
    return Dialect.get_or_raise(dialect).normalize_identifier(identifier)


def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Returns `schema` if it's already a `Schema`, otherwise wraps it in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, **kwargs)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """Coerces a dict, "name: type, ..." string, or list of names into a column mapping dict."""
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        # NOTE(review): splitting on every ":" means a type containing ":" would be
        # truncated to its first segment — presumably types here never contain ":".
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    elif isinstance(mapping, list):
        # Only names are known; types are left as None.
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Returns the list of key paths (one per table) of a nested schema mapping."""
    tables = []
    keys = keys or []
    # Exclude the column level by default.
    depth = dict_depth(schema) - 1 if depth is None else depth

    for k, v in schema.items():
        if depth == 1 or not isinstance(v, dict):
            tables.append(keys + [k])
        elif depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
19class Schema(abc.ABC): 20 """Abstract base class for database schemas""" 21 22 dialect: DialectType 23 24 @abc.abstractmethod 25 def add_table( 26 self, 27 table: exp.Table | str, 28 column_mapping: t.Optional[ColumnMapping] = None, 29 dialect: DialectType = None, 30 normalize: t.Optional[bool] = None, 31 match_depth: bool = True, 32 ) -> None: 33 """ 34 Register or update a table. Some implementing classes may require column information to also be provided. 35 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 36 37 Args: 38 table: the `Table` expression instance or string representing the table. 39 column_mapping: a column mapping that describes the structure of the table. 40 dialect: the SQL dialect that will be used to parse `table` if it's a string. 41 normalize: whether to normalize identifiers according to the dialect of interest. 42 match_depth: whether to enforce that the table must match the schema's depth or not. 43 """ 44 45 @abc.abstractmethod 46 def column_names( 47 self, 48 table: exp.Table | str, 49 only_visible: bool = False, 50 dialect: DialectType = None, 51 normalize: t.Optional[bool] = None, 52 ) -> t.Sequence[str]: 53 """ 54 Get the column names for a table. 55 56 Args: 57 table: the `Table` expression instance. 58 only_visible: whether to include invisible columns. 59 dialect: the SQL dialect that will be used to parse `table` if it's a string. 60 normalize: whether to normalize identifiers according to the dialect of interest. 61 62 Returns: 63 The sequence of column names. 64 """ 65 66 @abc.abstractmethod 67 def get_column_type( 68 self, 69 table: exp.Table | str, 70 column: exp.Column | str, 71 dialect: DialectType = None, 72 normalize: t.Optional[bool] = None, 73 ) -> exp.DataType: 74 """ 75 Get the `sqlglot.exp.DataType` type of a column in the schema. 76 77 Args: 78 table: the source table. 79 column: the target column. 
80 dialect: the SQL dialect that will be used to parse `table` if it's a string. 81 normalize: whether to normalize identifiers according to the dialect of interest. 82 83 Returns: 84 The resulting column type. 85 """ 86 87 def has_column( 88 self, 89 table: exp.Table | str, 90 column: exp.Column | str, 91 dialect: DialectType = None, 92 normalize: t.Optional[bool] = None, 93 ) -> bool: 94 """ 95 Returns whether `column` appears in `table`'s schema. 96 97 Args: 98 table: the source table. 99 column: the target column. 100 dialect: the SQL dialect that will be used to parse `table` if it's a string. 101 normalize: whether to normalize identifiers according to the dialect of interest. 102 103 Returns: 104 True if the column appears in the schema, False otherwise. 105 """ 106 name = column if isinstance(column, str) else column.name 107 return name in self.column_names(table, dialect=dialect, normalize=normalize) 108 109 @property 110 @abc.abstractmethod 111 def supported_table_args(self) -> t.Tuple[str, ...]: 112 """ 113 Table arguments this schema support, e.g. `("this", "db", "catalog")` 114 """ 115 116 @property 117 def empty(self) -> bool: 118 """Returns whether the schema is empty.""" 119 return True
Abstract base class for database schemas
24 @abc.abstractmethod 25 def add_table( 26 self, 27 table: exp.Table | str, 28 column_mapping: t.Optional[ColumnMapping] = None, 29 dialect: DialectType = None, 30 normalize: t.Optional[bool] = None, 31 match_depth: bool = True, 32 ) -> None: 33 """ 34 Register or update a table. Some implementing classes may require column information to also be provided. 35 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 36 37 Args: 38 table: the `Table` expression instance or string representing the table. 39 column_mapping: a column mapping that describes the structure of the table. 40 dialect: the SQL dialect that will be used to parse `table` if it's a string. 41 normalize: whether to normalize identifiers according to the dialect of interest. 42 match_depth: whether to enforce that the table must match the schema's depth or not. 43 """
Register or update a table. Some implementing classes may require column information to also be provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
45 @abc.abstractmethod 46 def column_names( 47 self, 48 table: exp.Table | str, 49 only_visible: bool = False, 50 dialect: DialectType = None, 51 normalize: t.Optional[bool] = None, 52 ) -> t.Sequence[str]: 53 """ 54 Get the column names for a table. 55 56 Args: 57 table: the `Table` expression instance. 58 only_visible: whether to include invisible columns. 59 dialect: the SQL dialect that will be used to parse `table` if it's a string. 60 normalize: whether to normalize identifiers according to the dialect of interest. 61 62 Returns: 63 The sequence of column names. 64 """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the result to only the visible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
66 @abc.abstractmethod 67 def get_column_type( 68 self, 69 table: exp.Table | str, 70 column: exp.Column | str, 71 dialect: DialectType = None, 72 normalize: t.Optional[bool] = None, 73 ) -> exp.DataType: 74 """ 75 Get the `sqlglot.exp.DataType` type of a column in the schema. 76 77 Args: 78 table: the source table. 79 column: the target column. 80 dialect: the SQL dialect that will be used to parse `table` if it's a string. 81 normalize: whether to normalize identifiers according to the dialect of interest. 82 83 Returns: 84 The resulting column type. 85 """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
87 def has_column( 88 self, 89 table: exp.Table | str, 90 column: exp.Column | str, 91 dialect: DialectType = None, 92 normalize: t.Optional[bool] = None, 93 ) -> bool: 94 """ 95 Returns whether `column` appears in `table`'s schema. 96 97 Args: 98 table: the source table. 99 column: the target column. 100 dialect: the SQL dialect that will be used to parse `table` if it's a string. 101 normalize: whether to normalize identifiers according to the dialect of interest. 102 103 Returns: 104 True if the column appears in the schema, False otherwise. 105 """ 106 name = column if isinstance(column, str) else column.name 107 return name in self.column_names(table, dialect=dialect, normalize=normalize)
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
122class AbstractMappingSchema: 123 def __init__( 124 self, 125 mapping: t.Optional[t.Dict] = None, 126 ) -> None: 127 self.mapping = mapping or {} 128 self.mapping_trie = new_trie( 129 tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) 130 ) 131 self._supported_table_args: t.Tuple[str, ...] = tuple() 132 133 @property 134 def empty(self) -> bool: 135 return not self.mapping 136 137 def depth(self) -> int: 138 return dict_depth(self.mapping) 139 140 @property 141 def supported_table_args(self) -> t.Tuple[str, ...]: 142 if not self._supported_table_args and self.mapping: 143 depth = self.depth() 144 145 if not depth: # None 146 self._supported_table_args = tuple() 147 elif 1 <= depth <= 3: 148 self._supported_table_args = exp.TABLE_PARTS[:depth] 149 else: 150 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 151 152 return self._supported_table_args 153 154 def table_parts(self, table: exp.Table) -> t.List[str]: 155 if isinstance(table.this, exp.ReadCSV): 156 return [table.this.name] 157 return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] 158 159 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 160 """ 161 Returns the schema of a given table. 162 163 Args: 164 table: the target table. 165 raise_on_missing: whether to raise in case the schema is not found. 166 167 Returns: 168 The schema of the target table. 
169 """ 170 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 171 value, trie = in_trie(self.mapping_trie, parts) 172 173 if value == TrieResult.FAILED: 174 return None 175 176 if value == TrieResult.PREFIX: 177 possibilities = flatten_schema(trie) 178 179 if len(possibilities) == 1: 180 parts.extend(possibilities[0]) 181 else: 182 message = ", ".join(".".join(parts) for parts in possibilities) 183 if raise_on_missing: 184 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 185 return None 186 187 return self.nested_get(parts, raise_on_missing=raise_on_missing) 188 189 def nested_get( 190 self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True 191 ) -> t.Optional[t.Any]: 192 return nested_get( 193 d or self.mapping, 194 *zip(self.supported_table_args, reversed(parts)), 195 raise_on_missing=raise_on_missing, 196 )
140 @property 141 def supported_table_args(self) -> t.Tuple[str, ...]: 142 if not self._supported_table_args and self.mapping: 143 depth = self.depth() 144 145 if not depth: # None 146 self._supported_table_args = tuple() 147 elif 1 <= depth <= 3: 148 self._supported_table_args = exp.TABLE_PARTS[:depth] 149 else: 150 raise SchemaError(f"Invalid mapping shape. Depth: {depth}") 151 152 return self._supported_table_args
159 def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: 160 """ 161 Returns the schema of a given table. 162 163 Args: 164 table: the target table. 165 raise_on_missing: whether to raise in case the schema is not found. 166 167 Returns: 168 The schema of the target table. 169 """ 170 parts = self.table_parts(table)[0 : len(self.supported_table_args)] 171 value, trie = in_trie(self.mapping_trie, parts) 172 173 if value == TrieResult.FAILED: 174 return None 175 176 if value == TrieResult.PREFIX: 177 possibilities = flatten_schema(trie) 178 179 if len(possibilities) == 1: 180 parts.extend(possibilities[0]) 181 else: 182 message = ", ".join(".".join(parts) for parts in possibilities) 183 if raise_on_missing: 184 raise SchemaError(f"Ambiguous mapping for {table}: {message}.") 185 return None 186 187 return self.nested_get(parts, raise_on_missing=raise_on_missing)
Returns the schema of a given table.
Arguments:
- table: the target table.
- raise_on_missing: whether to raise in case the schema is not found.
Returns:
The schema of the target table.
class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = {} if visible is None else visible
        self.normalize = normalize
        # Cache of string type -> parsed exp.DataType, filled by _to_data_type.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        # Lazily computed in depth(); 0 means "not computed yet".
        self._depth = 0
        schema = {} if schema is None else schema

        super().__init__(self._normalize(schema) if self.normalize else schema)

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        # Alternate constructor: re-wrap an existing instance.
        # NOTE(review): the underlying mapping/visible dicts are shared, not copied —
        # use copy() if independence is needed.
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        # Shallow copy of the top-level mapping/visible dicts; kwargs override any field.
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
        match_depth: bool = True,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.
        The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
            match_depth: whether to enforce that the table must match the schema's depth or not.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
            raise SchemaError(
                f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
                f"schema's nesting level: {self.depth()}."
            )

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            # Table already registered and no new columns given: keep it as-is.
            return

        parts = self.table_parts(normalized_table)

        # parts come most-specific first; the mapping is nested catalog -> db -> table,
        # hence the reversal before writing.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        # Resolve the table's columns; [] when the table is unknown.
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        # Look up the column's type; "unknown" when the table or column is missing.
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                # String types are parsed (and cached) on demand.
                return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

    def has_column(
        self,
        table: exp.Table | str,
        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> bool:
        # Membership test against the table's column mapping; False for unknown tables.
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        return normalized_column_name in table_schema if table_schema else False

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        normalized_mapping: t.Dict = {}
        flattened_schema = flatten_schema(schema)
        error_msg = "Table {} must match the schema's nesting level: {}."

        for keys in flattened_schema:
            # nested_get expects (name, key) pairs; here the key doubles as its own label.
            columns = nested_get(schema, *zip(keys, keys))

            if not isinstance(columns, dict):
                # Path too shallow: a table entry must map columns to types.
                raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
            if isinstance(first(columns.values()), dict):
                # Path too deep: column values must be types, not further nesting.
                raise SchemaError(
                    error_msg.format(
                        ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
                    ),
                )

            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        # Parse `table` if given as a string and normalize each of its identifier parts.
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        # Copy only when we will mutate the parts below.
        normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize)

        if normalize:
            for arg in exp.TABLE_PARTS:
                value = normalized_table.args.get(arg)
                if isinstance(value, exp.Identifier):
                    normalized_table.set(
                        arg,
                        normalize_name(value, dialect=dialect, is_table=True, normalize=normalize),
                    )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        # Instance-level wrapper around the module-level normalize_name, returning a plain string.
        return normalize_name(
            name,
            dialect=dialect or self.dialect,
            is_table=is_table,
            normalize=self.normalize if normalize is None else normalize,
        ).name

    def depth(self) -> int:
        if not self.empty and not self._depth:
            # The columns themselves are a mapping, but we don't want to include those
            self._depth = super().depth() - 1
        return self._depth

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect
            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                self._type_mapping_cache[schema_type] = expression
            # NOTE(review): AttributeError is what surfaces here on a failed build — confirm
            # this is the intended failure mode rather than a ParseError.
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema: Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- {table: set(*cols)}
- {db: {table: set(*cols)}}
- {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
- normalize: Whether to normalize identifier names according to the given dialect or not.
218 def __init__( 219 self, 220 schema: t.Optional[t.Dict] = None, 221 visible: t.Optional[t.Dict] = None, 222 dialect: DialectType = None, 223 normalize: bool = True, 224 ) -> None: 225 self.dialect = dialect 226 self.visible = {} if visible is None else visible 227 self.normalize = normalize 228 self._type_mapping_cache: t.Dict[str, exp.DataType] = {} 229 self._depth = 0 230 schema = {} if schema is None else schema 231 232 super().__init__(self._normalize(schema) if self.normalize else schema)
254 def add_table( 255 self, 256 table: exp.Table | str, 257 column_mapping: t.Optional[ColumnMapping] = None, 258 dialect: DialectType = None, 259 normalize: t.Optional[bool] = None, 260 match_depth: bool = True, 261 ) -> None: 262 """ 263 Register or update a table. Updates are only performed if a new column mapping is provided. 264 The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. 265 266 Args: 267 table: the `Table` expression instance or string representing the table. 268 column_mapping: a column mapping that describes the structure of the table. 269 dialect: the SQL dialect that will be used to parse `table` if it's a string. 270 normalize: whether to normalize identifiers according to the dialect of interest. 271 match_depth: whether to enforce that the table must match the schema's depth or not. 272 """ 273 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 274 275 if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): 276 raise SchemaError( 277 f"Table {normalized_table.sql(dialect=self.dialect)} must match the " 278 f"schema's nesting level: {self.depth()}." 279 ) 280 281 normalized_column_mapping = { 282 self._normalize_name(key, dialect=dialect, normalize=normalize): value 283 for key, value in ensure_column_mapping(column_mapping).items() 284 } 285 286 schema = self.find(normalized_table, raise_on_missing=False) 287 if schema and not normalized_column_mapping: 288 return 289 290 parts = self.table_parts(normalized_table) 291 292 nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) 293 new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided. The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
- match_depth: whether to enforce that the table must match the schema's depth or not.
295 def column_names( 296 self, 297 table: exp.Table | str, 298 only_visible: bool = False, 299 dialect: DialectType = None, 300 normalize: t.Optional[bool] = None, 301 ) -> t.List[str]: 302 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 303 304 schema = self.find(normalized_table) 305 if schema is None: 306 return [] 307 308 if not only_visible or not self.visible: 309 return list(schema) 310 311 visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] 312 return [col for col in schema if col in visible]
Get the column names for a table.
Arguments:
- table: the `Table` expression instance.
- only_visible: whether to restrict the result to visible columns only.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The sequence of column names.
314 def get_column_type( 315 self, 316 table: exp.Table | str, 317 column: exp.Column | str, 318 dialect: DialectType = None, 319 normalize: t.Optional[bool] = None, 320 ) -> exp.DataType: 321 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 322 323 normalized_column_name = self._normalize_name( 324 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 325 ) 326 327 table_schema = self.find(normalized_table, raise_on_missing=False) 328 if table_schema: 329 column_type = table_schema.get(normalized_column_name) 330 331 if isinstance(column_type, exp.DataType): 332 return column_type 333 elif isinstance(column_type, str): 334 return self._to_data_type(column_type, dialect=dialect) 335 336 return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
338 def has_column( 339 self, 340 table: exp.Table | str, 341 column: exp.Column | str, 342 dialect: DialectType = None, 343 normalize: t.Optional[bool] = None, 344 ) -> bool: 345 normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) 346 347 normalized_column_name = self._normalize_name( 348 column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize 349 ) 350 351 table_schema = self.find(normalized_table, raise_on_missing=False) 352 return normalized_column_name in table_schema if table_schema else False
Returns whether `column` appears in `table`'s schema.
Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
- normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
True if the column appears in the schema, False otherwise.
Inherited Members
457def normalize_name( 458 identifier: str | exp.Identifier, 459 dialect: DialectType = None, 460 is_table: bool = False, 461 normalize: t.Optional[bool] = True, 462) -> exp.Identifier: 463 if isinstance(identifier, str): 464 identifier = exp.parse_identifier(identifier, dialect=dialect) 465 466 if not normalize: 467 return identifier 468 469 # this is used for normalize_identifier, bigquery has special rules pertaining tables 470 identifier.meta["is_table"] = is_table 471 return Dialect.get_or_raise(dialect).normalize_identifier(identifier)
481def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: 482 if mapping is None: 483 return {} 484 elif isinstance(mapping, dict): 485 return mapping 486 elif isinstance(mapping, str): 487 col_name_type_strs = [x.strip() for x in mapping.split(",")] 488 return { 489 name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() 490 for name_type_str in col_name_type_strs 491 } 492 elif isinstance(mapping, list): 493 return {x.strip(): None for x in mapping} 494 495 raise ValueError(f"Invalid mapping provided: {type(mapping)}")
498def flatten_schema( 499 schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None 500) -> t.List[t.List[str]]: 501 tables = [] 502 keys = keys or [] 503 depth = dict_depth(schema) - 1 if depth is None else depth 504 505 for k, v in schema.items(): 506 if depth == 1 or not isinstance(v, dict): 507 tables.append(keys + [k]) 508 elif depth >= 2: 509 tables.extend(flatten_schema(v, depth - 1, keys + [k])) 510 511 return tables
514def nested_get( 515 d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True 516) -> t.Optional[t.Any]: 517 """ 518 Get a value for a nested dictionary. 519 520 Args: 521 d: the dictionary to search. 522 *path: tuples of (name, key), where: 523 `key` is the key in the dictionary to get. 524 `name` is a string to use in the error if `key` isn't found. 525 526 Returns: 527 The value or None if it doesn't exist. 528 """ 529 for name, key in path: 530 d = d.get(key) # type: ignore 531 if d is None: 532 if raise_on_missing: 533 name = "table" if name == "this" else name 534 raise ValueError(f"Unknown {name}: {key}") 535 return None 536 537 return d
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where:
`key` is the key in the dictionary to get; `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
540def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: 541 """ 542 In-place set a value for a nested dictionary 543 544 Example: 545 >>> nested_set({}, ["top_key", "second_key"], "value") 546 {'top_key': {'second_key': 'value'}} 547 548 >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") 549 {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} 550 551 Args: 552 d: dictionary to update. 553 keys: the keys that makeup the path to `value`. 554 value: the value to set in the dictionary for the given key path. 555 556 Returns: 557 The (possibly) updated dictionary. 558 """ 559 if not keys: 560 return d 561 562 if len(keys) == 1: 563 d[keys[0]] = value 564 return d 565 566 subd = d 567 for key in keys[:-1]: 568 if key not in subd: 569 subd = subd.setdefault(key, {}) 570 else: 571 subd = subd[key] 572 573 subd[keys[-1]] = value 574 return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.