sqlglot.schema
from __future__ import annotations

import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

TABLE_ARGS = ("this", "db", "catalog")

T = t.TypeVar("T")


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """

    @abc.abstractmethod
    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to include invisible columns.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.

        Returns:
            The resulting column type.
        """

    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True


class AbstractMappingSchema(t.Generic[T]):
    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        self.mapping_trie = new_trie(
            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        return not self.mapping

    def _depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        if not self._supported_table_args and self.mapping:
            depth = self._depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == 0:
            return None

        if value == 1:
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(parts) for parts in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        return nested_get(
            d or self.mapping,
            *zip(self.supported_table_args, reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
        """
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )
        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
    ) -> t.List[str]:
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )

        schema = self.find(normalized_table)
        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
    ) -> exp.DataType:
        normalized_table = self._normalize_table(
            self._ensure_table(table, dialect=dialect), dialect=dialect
        )
        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

            raise SchemaError(f"Unknown column type '{column_type}'")

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table:
        normalized_table = table.copy()

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(arg, self._normalize_name(value, dialect=dialect))

        return normalized_table

    def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str:
        dialect = dialect or self.dialect

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        return identifier.name if identifier.quoted else identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
        if isinstance(table, exp.Table):
            return table

        dialect = dialect or self.dialect
        parsed_table = sqlglot.parse_one(table, read=dialect, into=exp.Table)

        if not parsed_table:
            in_dialect = f" in dialect {dialect}" if dialect else ""
            raise SchemaError(f"Failed to parse table '{table}'{in_dialect}.")

        return parsed_table

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]


def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, dialect=dialect)


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    tables = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])

    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None

    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    if len(keys) == 1:
        d[keys[0]] = value
        return d

    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]

    subd[keys[-1]] = value
    return d
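As a quick orientation before the per-class documentation below, here is a small usage sketch of the module's public surface (table and column names are illustrative):

from sqlglot.schema import ensure_schema

# ensure_schema wraps a plain nested dict in a MappingSchema.
schema = ensure_schema({"db": {"orders": {"id": "int", "total": "double"}}})

print(schema.supported_table_args)       # ('this', 'db')
print(schema.column_names("db.orders"))  # ['id', 'total']

# Tables can also be registered after construction.
schema.add_table("db.users", {"id": "int", "email": "text"})
print(schema.column_names("db.users"))   # ['id', 'email']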
class Schema(abc.ABC):
Abstract base class for database schemas
@abc.abstractmethod
def add_table(self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None, dialect: DialectType = None) -> None
Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
@abc.abstractmethod
def column_names(self, table: exp.Table | str, only_visible: bool = False, dialect: DialectType = None) -> t.List[str]
Get the column names for a table.

Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.

Returns:
The list of column names.
@abc.abstractmethod
def get_column_type(self, table: exp.Table | str, column: exp.Column, dialect: DialectType = None) -> exp.DataType
Get the `sqlglot.exp.DataType` type of a column in the schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.

Returns:
The resulting column type.
class AbstractMappingSchema(t.Generic[T]):
def find(self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True) -> t.Optional[T]
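find resolves a table against the mapping trie, filling in missing qualifiers when they are unambiguous. A small behavioural sketch (database and table names are illustrative):

from sqlglot import exp
from sqlglot.schema import MappingSchema

# Two databases that both contain a table named "orders".
schema = MappingSchema({
    "db1": {"orders": {"id": "int"}},
    "db2": {"orders": {"id": "int"}},
})

# A fully qualified table resolves directly to its column mapping.
print(schema.find(exp.to_table("db1.orders")))  # {'id': 'int'}

# An unqualified name is ambiguous between db1 and db2: find() raises
# SchemaError unless raise_on_missing=False, in which case it returns None.
print(schema.find(exp.to_table("orders"), raise_on_missing=False))  # None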
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Schema based on a nested mapping.

Arguments:
- schema: Mapping in one of the following forms:
  1. {table: {col: type}}
  2. {db: {table: {col: type}}}
  3. {catalog: {db: {table: {col: type}}}}
  4. None - Tables will be added later
- visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
  1. {table: set(*cols)}
  2. {db: {table: set(*cols)}}
  3. {catalog: {db: {table: set(*cols)}}}
- dialect: The dialect to be used for custom type mappings & parsing string arguments.
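For illustration, a minimal construction sketch using a depth-2 mapping (table and column names are made up):

from sqlglot.schema import MappingSchema

# {db: {table: {col: type}}}
schema = MappingSchema(schema={"store": {"orders": {"id": "INT", "note": "TEXT"}}})

print(schema.column_names("store.orders"))              # ['id', 'note']
print(schema.get_column_type("store.orders", "id").this)  # Type.INT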
def __init__(self, schema: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None, dialect: DialectType = None) -> None
def add_table(self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None, dialect: DialectType = None) -> None
Register or update a table. Updates are only performed if a new column mapping is provided.

Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.
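The column mapping can be given in any of the shapes accepted by ensure_column_mapping below (dict, comma-separated "name: type" string, list of names, or a PySpark-style StructType). A sketch with made-up table names:

from sqlglot.schema import MappingSchema

schema = MappingSchema()

# Register a table with a dict mapping ...
schema.add_table("orders", {"id": "int", "note": "text"})

# ... or with a "name: type" string, which is parsed into the same shape.
schema.add_table("users", "id: INT, email: TEXT")

print(schema.column_names("orders"))  # ['id', 'note']
print(schema.column_names("users"))   # ['id', 'email']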
def column_names(self, table: exp.Table | str, only_visible: bool = False, dialect: DialectType = None) -> t.List[str]
Get the column names for a table.

Arguments:
- table: the `Table` expression instance.
- only_visible: whether to include invisible columns.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.

Returns:
The list of column names.
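A sketch of how the visible mapping interacts with only_visible (names are illustrative):

from sqlglot.schema import MappingSchema

schema = MappingSchema(
    schema={"users": {"id": "int", "email": "text", "password_hash": "text"}},
    visible={"users": {"id", "email"}},
)

print(schema.column_names("users"))                     # ['id', 'email', 'password_hash']
print(schema.column_names("users", only_visible=True))  # ['id', 'email']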
def get_column_type(self, table: exp.Table | str, column: exp.Column, dialect: DialectType = None) -> exp.DataType
Get the `sqlglot.exp.DataType` type of a column in the schema.

Arguments:
- table: the source table.
- column: the target column.
- dialect: the SQL dialect that will be used to parse `table` if it's a string.

Returns:
The resulting column type.
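A brief sketch (names are illustrative); note that a table or column missing from the mapping yields the UNKNOWN type rather than an error:

from sqlglot import exp
from sqlglot.schema import MappingSchema

schema = MappingSchema({"orders": {"id": "int", "total": "decimal(10, 2)"}})

# String types are built into exp.DataType instances on demand.
print(schema.get_column_type("orders", exp.column("total")).sql())  # roughly "DECIMAL(10, 2)"

# Unknown tables fall back to exp.DataType.build("unknown").
print(schema.get_column_type("missing", exp.column("x")).this)      # Type.UNKNOWN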
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict
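A sketch of the accepted input shapes and their normalized results:

from sqlglot.schema import ensure_column_mapping

print(ensure_column_mapping({"a": "INT"}))       # {'a': 'INT'}
print(ensure_column_mapping("a: INT, b: TEXT"))  # {'a': 'INT', 'b': 'TEXT'}
print(ensure_column_mapping(["a", "b"]))         # {'a': None, 'b': None}
print(ensure_column_mapping(None))               # {}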
def flatten_schema(schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]
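For example, flattening a depth-2 mapping yields the key paths down to (but not including) the column level:

from sqlglot.schema import flatten_schema

nested = {"db": {"t1": {"a": "int"}, "t2": {"b": "int"}}}

# depth counts the levels above the columns: {db: {table: {...}}} has depth 2.
print(flatten_schema(nested, depth=2))  # [['db', 't1'], ['db', 't2']]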
def nested_get(d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True) -> t.Optional[t.Any]
Get a value for a nested dictionary.

Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where `key` is the key in the dictionary to get and `name` is a string to use in the error if `key` isn't found.

Returns:
The value or None if it doesn't exist.
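A short sketch of the (name, key) path tuples:

from sqlglot.schema import nested_get

schema = {"db": {"orders": {"id": "int"}}}

print(nested_get(schema, ("db", "db"), ("table", "orders")))  # {'id': 'int'}

# With raise_on_missing=False a missing key returns None instead of raising ValueError.
print(nested_get(schema, ("db", "db"), ("table", "missing"), raise_on_missing=False))  # None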
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict
In-place set a value for a nested dictionary.

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}

>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.

Returns:
The (possibly) updated dictionary.