
sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection, Set
 10from contextlib import contextmanager
 11from copy import copy
 12from enum import Enum
 13from itertools import count
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot import exp
 17    from sqlglot._typing import A, E, T
 18    from sqlglot.expressions import Expression
 19
 20
 21CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 22PYTHON_VERSION = sys.version_info[:2]
 23logger = logging.getLogger("sqlglot")
 24
 25
 26class AutoName(Enum):
 27    """
 28    This is used for creating Enum classes where `auto()` is the string form
 29    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 30
 31    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 32    """
 33
 34    def _generate_next_value_(name, _start, _count, _last_values):
 35        return name
 36
 37
 38class classproperty(property):
 39    """
 40    Similar to a normal property but works for class methods
 41    """
 42
 43    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 44        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 45
 46
 47def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 48    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 49    try:
 50        return seq[index]
 51    except IndexError:
 52        return None
 53
 54
 55@t.overload
 56def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 57
 58
 59@t.overload
 60def ensure_list(value: T) -> t.List[T]: ...
 61
 62
 63def ensure_list(value):
 64    """
 65    Ensures that a value is a list, otherwise casts or wraps it into one.
 66
 67    Args:
 68        value: The value of interest.
 69
 70    Returns:
 71        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 72    """
 73    if value is None:
 74        return []
 75    if isinstance(value, (list, tuple)):
 76        return list(value)
 77
 78    return [value]
 79
 80
 81@t.overload
 82def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
 83
 84
 85@t.overload
 86def ensure_collection(value: T) -> t.Collection[T]: ...
 87
 88
 89def ensure_collection(value):
 90    """
 91    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 92
 93    Args:
 94        value: The value of interest.
 95
 96    Returns:
 97        The value if it's a collection, or else the value wrapped in a list.
 98    """
 99    if value is None:
100        return []
101    return (
102        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
103    )
104
105
106def csv(*args: str, sep: str = ", ") -> str:
107    """
108    Formats any number of string arguments as CSV.
109
110    Args:
111        args: The string arguments to format.
112        sep: The argument separator.
113
114    Returns:
115        The arguments formatted as a CSV string.
116    """
117    return sep.join(arg for arg in args if arg)
118
119
120def subclasses(
121    module_name: str,
122    classes: t.Type | t.Tuple[t.Type, ...],
123    exclude: t.Type | t.Tuple[t.Type, ...] = (),
124) -> t.List[t.Type]:
125    """
126    Returns all subclasses for a collection of classes, possibly excluding some of them.
127
128    Args:
129        module_name: The name of the module to search for subclasses in.
130        classes: Class(es) we want to find the subclasses of.
131        exclude: Class(es) we want to exclude from the returned list.
132
133    Returns:
134        The target subclasses.
135    """
136    return [
137        obj
138        for _, obj in inspect.getmembers(
139            sys.modules[module_name],
140            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
141        )
142    ]
143
144
145def apply_index_offset(
146    this: exp.Expression,
147    expressions: t.List[E],
148    offset: int,
149) -> t.List[E]:
150    """
151    Applies an offset to a given integer literal expression.
152
153    Args:
154        this: The target of the index.
155        expressions: The expression the offset will be applied to, wrapped in a list.
156        offset: The offset that will be applied.
157
158    Returns:
159        The original expression with the offset applied to it, wrapped in a list. If the provided
160        `expressions` argument contains more than one expression, it's returned unaffected.
161    """
162    if not offset or len(expressions) != 1:
163        return expressions
164
165    expression = expressions[0]
166
167    from sqlglot import exp
168    from sqlglot.optimizer.annotate_types import annotate_types
169    from sqlglot.optimizer.simplify import simplify
170
171    if not this.type:
172        annotate_types(this)
173
174    if t.cast(exp.DataType, this.type).this not in (
175        exp.DataType.Type.UNKNOWN,
176        exp.DataType.Type.ARRAY,
177    ):
178        return expressions
179
180    if not expression.type:
181        annotate_types(expression)
182    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
183        logger.warning("Applying array index offset (%s)", offset)
184        expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
185        return [expression]
186
187    return expressions
188
189
190def camel_to_snake_case(name: str) -> str:
191    """Converts `name` from camelCase to snake_case and returns the result."""
192    return CAMEL_CASE_PATTERN.sub("_", name).upper()
193
194
195def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
196    """
197    Applies a transformation to a given expression until a fixed point is reached.
198
199    Args:
200        expression: The expression to be transformed.
201        func: The transformation to be applied.
202
203    Returns:
204        The transformed expression.
205    """
206    while True:
207        for n, *_ in reversed(tuple(expression.walk())):
208            n._hash = hash(n)
209
210        start = hash(expression)
211        expression = func(expression)
212
213        for n, *_ in expression.walk():
214            n._hash = None
215        if start == hash(expression):
216            break
217
218    return expression
219
220
221def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
222    """
223    Sorts a given directed acyclic graph in topological order.
224
225    Args:
226        dag: The graph to be sorted.
227
228    Returns:
229        A list that contains all of the graph's nodes in topological order.
230    """
231    result = []
232
233    for node, deps in tuple(dag.items()):
234        for dep in deps:
235            if dep not in dag:
236                dag[dep] = set()
237
238    while dag:
239        current = {node for node, deps in dag.items() if not deps}
240
241        if not current:
242            raise ValueError("Cycle error")
243
244        for node in current:
245            dag.pop(node)
246
247        for deps in dag.values():
248            deps -= current
249
250        result.extend(sorted(current))  # type: ignore
251
252    return result
253
254
255def open_file(file_name: str) -> t.TextIO:
256    """Open a file that may be compressed as gzip and return it in universal newline mode."""
257    with open(file_name, "rb") as f:
258        gzipped = f.read(2) == b"\x1f\x8b"
259
260    if gzipped:
261        import gzip
262
263        return gzip.open(file_name, "rt", newline="")
264
265    return open(file_name, encoding="utf-8", newline="")
266
267
268@contextmanager
269def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
270    """
271    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
272
273    Args:
274        read_csv: A `ReadCSV` function call.
275
276    Yields:
277        A python csv reader.
278    """
279    args = read_csv.expressions
280    file = open_file(read_csv.name)
281
282    delimiter = ","
283    args = iter(arg.name for arg in args)  # type: ignore
284    for k, v in zip(args, args):
285        if k == "delimiter":
286            delimiter = v
287
288    try:
289        import csv as csv_
290
291        yield csv_.reader(file, delimiter=delimiter)
292    finally:
293        file.close()
294
295
296def find_new_name(taken: t.Collection[str], base: str) -> str:
297    """
298    Searches for a new name.
299
300    Args:
301        taken: A collection of taken names.
302        base: Base name to alter.
303
304    Returns:
305        The new, available name.
306    """
307    if base not in taken:
308        return base
309
310    i = 2
311    new = f"{base}_{i}"
312    while new in taken:
313        i += 1
314        new = f"{base}_{i}"
315
316    return new
317
318
319def is_int(text: str) -> bool:
320    try:
321        int(text)
322        return True
323    except ValueError:
324        return False
325
326
327def name_sequence(prefix: str) -> t.Callable[[], str]:
328    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
329    sequence = count()
330    return lambda: f"{prefix}{next(sequence)}"
331
332
333def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
334    """Returns a dictionary created from an object's attributes."""
335    return {
336        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
337        **kwargs,
338    }
339
340
341def split_num_words(
342    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
343) -> t.List[t.Optional[str]]:
344    """
345    Perform a split on a value and return N words as a result with `None` used for words that don't exist.
346
347    Args:
348        value: The value to be split.
349        sep: The value to use to split on.
350        min_num_words: The minimum number of words that are going to be in the result.
351        fill_from_start: Indicates whether `None` values should be inserted at the start or the end of the list.
352
353    Examples:
354        >>> split_num_words("db.table", ".", 3)
355        [None, 'db', 'table']
356        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
357        ['db', 'table', None]
358        >>> split_num_words("db.table", ".", 1)
359        ['db', 'table']
360
361    Returns:
362        The list of words returned by `split`, possibly augmented by a number of `None` values.
363    """
364    words = value.split(sep)
365    if fill_from_start:
366        return [None] * (min_num_words - len(words)) + words
367    return words + [None] * (min_num_words - len(words))
368
369
370def is_iterable(value: t.Any) -> bool:
371    """
372    Checks if the value is an iterable, excluding the types `str` and `bytes`.
373
374    Examples:
375        >>> is_iterable([1,2])
376        True
377        >>> is_iterable("test")
378        False
379
380    Args:
381        value: The value to check if it is an iterable.
382
383    Returns:
384        A `bool` value indicating if it is an iterable.
385    """
386    from sqlglot import Expression
387
388    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
389
390
391def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
392    """
393    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
394    type `str` and `bytes` are not regarded as iterables.
395
396    Examples:
397        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
398        [1, 2, 3, 4, 5, 'bla']
399        >>> list(flatten([1, 2, 3]))
400        [1, 2, 3]
401
402    Args:
403        values: The value to be flattened.
404
405    Yields:
406        Non-iterable elements in `values`.
407    """
408    for value in values:
409        if is_iterable(value):
410            yield from flatten(value)
411        else:
412            yield value
413
414
415def dict_depth(d: t.Dict) -> int:
416    """
417    Get the nesting depth of a dictionary.
418
419    Example:
420        >>> dict_depth(None)
421        0
422        >>> dict_depth({})
423        1
424        >>> dict_depth({"a": "b"})
425        1
426        >>> dict_depth({"a": {}})
427        2
428        >>> dict_depth({"a": {"b": {}}})
429        3
430    """
431    try:
432        return 1 + dict_depth(next(iter(d.values())))
433    except AttributeError:
434        # d doesn't have attribute "values"
435        return 0
436    except StopIteration:
437        # d.values() returns an empty sequence
438        return 1
439
440
441def first(it: t.Iterable[T]) -> T:
442    """Returns the first element from an iterable (useful for sets)."""
443    return next(i for i in it)
444
445
446def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
447    """
448    Merges a sequence of ranges, represented as tuples (low, high) whose values
449    belong to some totally-ordered set.
450
451    Example:
452        >>> merge_ranges([(1, 3), (2, 6)])
453        [(1, 6)]
454    """
455    if not ranges:
456        return []
457
458    ranges = sorted(ranges)
459
460    merged = [ranges[0]]
461
462    for start, end in ranges[1:]:
463        last_start, last_end = merged[-1]
464
465        if start <= last_end:
466            merged[-1] = (last_start, max(last_end, end))
467        else:
468            merged.append((start, end))
469
470    return merged
471
472
473def is_iso_date(text: str) -> bool:
474    try:
475        datetime.date.fromisoformat(text)
476        return True
477    except ValueError:
478        return False
479
480
481def is_iso_datetime(text: str) -> bool:
482    try:
483        datetime.datetime.fromisoformat(text)
484        return True
485    except ValueError:
486        return False
487
488
489# Interval units that operate on date components
490DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
491
492
493def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
494    return expression is not None and expression.name.lower() in DATE_UNITS
495
496
497K = t.TypeVar("K")
498V = t.TypeVar("V")
499
500
501class SingleValuedMapping(t.Mapping[K, V]):
502    """
503    Mapping where all keys return the same value.
504
505    This rigamarole is meant to avoid copying keys, which was originally intended
506    as an optimization while qualifying columns for tables with lots of columns.
507    """
508
509    def __init__(self, keys: t.Collection[K], value: V):
510        self._keys = keys if isinstance(keys, Set) else set(keys)
511        self._value = value
512
513    def __getitem__(self, key: K) -> V:
514        if key in self._keys:
515            return self._value
516        raise KeyError(key)
517
518    def __len__(self) -> int:
519        return len(self._keys)
520
521    def __iter__(self) -> t.Iterator[K]:
522        return iter(self._keys)
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
27class AutoName(Enum):
28    """
29    This is used for creating Enum classes where `auto()` is the string form
30    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
31
32    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
33    """
34
35    def _generate_next_value_(name, _start, _count, _last_values):
36        return name

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values

Inherited Members
enum.Enum
name
value
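A minimal usage sketch (the `Color` enum below is hypothetical, not part of sqlglot), showing that `auto()` yields the member's own name as its value:

>>> from enum import auto
>>> from sqlglot.helper import AutoName
>>> class Color(AutoName):
...     RED = auto()
...     GREEN = auto()
>>> Color.RED.value
'RED'
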
class classproperty(builtins.property):
39class classproperty(property):
40    """
41    Similar to a normal property but works for class methods
42    """
43
44    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
45        return classmethod(self.fget).__get__(None, owner)()  # type: ignore

Similar to a normal property but works for class methods

Inherited Members
builtins.property
property
getter
setter
deleter
fget
fset
fdel
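A minimal usage sketch (the `Config` class is hypothetical), showing that the property is readable on the class itself, without an instance:

>>> from sqlglot.helper import classproperty
>>> class Config:
...     _name = "default"
...     @classproperty
...     def name(cls):
...         return cls._name
>>> Config.name
'default'
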
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
48def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
49    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
50    try:
51        return seq[index]
52    except IndexError:
53        return None

Returns the value in seq at position index, or None if index is out of bounds.
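A quick sketch of the out-of-bounds behavior:

>>> from sqlglot.helper import seq_get
>>> seq_get(["a", "b"], 1)
'b'
>>> seq_get(["a", "b"], 5) is None
True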

def ensure_list(value):
64def ensure_list(value):
65    """
66    Ensures that a value is a list, otherwise casts or wraps it into one.
67
68    Args:
69        value: The value of interest.
70
71    Returns:
72        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
73    """
74    if value is None:
75        return []
76    if isinstance(value, (list, tuple)):
77        return list(value)
78
79    return [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
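A quick sketch of the three cases (`None`, list/tuple, scalar):

>>> from sqlglot.helper import ensure_list
>>> ensure_list(None)
[]
>>> ensure_list((1, 2))
[1, 2]
>>> ensure_list(1)
[1]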

def ensure_collection(value):
 90def ensure_collection(value):
 91    """
 92    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 93
 94    Args:
 95        value: The value of interest.
 96
 97    Returns:
 98        The value if it's a collection, or else the value wrapped in a list.
 99    """
100    if value is None:
101        return []
102    return (
103        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
104    )

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.
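A quick sketch; note that strings are wrapped rather than treated as collections:

>>> from sqlglot.helper import ensure_collection
>>> ensure_collection("foo")
['foo']
>>> sorted(ensure_collection({2, 1}))
[1, 2]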

def csv(*args: str, sep: str = ', ') -> str:
107def csv(*args: str, sep: str = ", ") -> str:
108    """
109    Formats any number of string arguments as CSV.
110
111    Args:
112        args: The string arguments to format.
113        sep: The argument separator.
114
115    Returns:
116        The arguments formatted as a CSV string.
117    """
118    return sep.join(arg for arg in args if arg)

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.
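A quick sketch; empty arguments are dropped:

>>> from sqlglot.helper import csv
>>> csv("a", "b", "", "c")
'a, b, c'
>>> csv("a", "b", sep=" | ")
'a | b'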

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
121def subclasses(
122    module_name: str,
123    classes: t.Type | t.Tuple[t.Type, ...],
124    exclude: t.Type | t.Tuple[t.Type, ...] = (),
125) -> t.List[t.Type]:
126    """
127    Returns all subclasses for a collection of classes, possibly excluding some of them.
128
129    Args:
130        module_name: The name of the module to search for subclasses in.
131        classes: Class(es) we want to find the subclasses of.
132        exclude: Class(es) we want to exclude from the returned list.
133
134    Returns:
135        The target subclasses.
136    """
137    return [
138        obj
139        for _, obj in inspect.getmembers(
140            sys.modules[module_name],
141            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
142        )
143    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.
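A sketch against `sqlglot.expressions`, assuming `exp.Abs` is defined there as a subclass of `exp.Func`:

>>> from sqlglot import exp
>>> from sqlglot.helper import subclasses
>>> funcs = subclasses(exp.__name__, exp.Func, exclude=(exp.Func,))
>>> exp.Abs in funcs
True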

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int) -> List[~E]:
146def apply_index_offset(
147    this: exp.Expression,
148    expressions: t.List[E],
149    offset: int,
150) -> t.List[E]:
151    """
152    Applies an offset to a given integer literal expression.
153
154    Args:
155        this: The target of the index.
156        expressions: The expression the offset will be applied to, wrapped in a list.
157        offset: The offset that will be applied.
158
159    Returns:
160        The original expression with the offset applied to it, wrapped in a list. If the provided
161        `expressions` argument contains more than one expression, it's returned unaffected.
162    """
163    if not offset or len(expressions) != 1:
164        return expressions
165
166    expression = expressions[0]
167
168    from sqlglot import exp
169    from sqlglot.optimizer.annotate_types import annotate_types
170    from sqlglot.optimizer.simplify import simplify
171
172    if not this.type:
173        annotate_types(this)
174
175    if t.cast(exp.DataType, this.type).this not in (
176        exp.DataType.Type.UNKNOWN,
177        exp.DataType.Type.ARRAY,
178    ):
179        return expressions
180
181    if not expression.type:
182        annotate_types(expression)
183    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
184        logger.warning("Applying array index offset (%s)", offset)
185        expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
186        return [expression]
187
188    return expressions

Applies an offset to a given integer literal expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
Returns:

The original expression with the offset applied to it, wrapped in a list. If the provided expressions argument contains more than one expression, it's returned unaffected.
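A sketch of the happy path, assuming a column of unknown/array type and an integer literal index (`exp.column` and `exp.Literal.number` come from `sqlglot.expressions`); the offset is folded into the literal:

>>> from sqlglot import exp
>>> from sqlglot.helper import apply_index_offset
>>> col = exp.column("arr")
>>> apply_index_offset(col, [exp.Literal.number(1)], -1)[0].sql()
'0'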

def camel_to_snake_case(name: str) -> str:
191def camel_to_snake_case(name: str) -> str:
192    """Converts `name` from camelCase to snake_case and returns the result."""
193    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase to snake_case and returns the result.
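Note that, per the `.upper()` call in the source, the result is upper-cased:

>>> from sqlglot.helper import camel_to_snake_case
>>> camel_to_snake_case("indexOffset")
'INDEX_OFFSET'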

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
196def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
197    """
198    Applies a transformation to a given expression until a fixed point is reached.
199
200    Args:
201        expression: The expression to be transformed.
202        func: The transformation to be applied.
203
204    Returns:
205        The transformed expression.
206    """
207    while True:
208        for n, *_ in reversed(tuple(expression.walk())):
209            n._hash = hash(n)
210
211        start = hash(expression)
212        expression = func(expression)
213
214        for n, *_ in expression.walk():
215            n._hash = None
216        if start == hash(expression):
217            break
218
219    return expression

Applies a transformation to a given expression until a fixed point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.
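A sketch using `sqlglot.optimizer.simplify.simplify` as the transformation, reapplied until the SQL stops changing:

>>> import sqlglot
>>> from sqlglot.optimizer.simplify import simplify
>>> from sqlglot.helper import while_changing
>>> while_changing(sqlglot.parse_one("1 + 1 + 1"), simplify).sql()
'3'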

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
222def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
223    """
224    Sorts a given directed acyclic graph in topological order.
225
226    Args:
227        dag: The graph to be sorted.
228
229    Returns:
230        A list that contains all of the graph's nodes in topological order.
231    """
232    result = []
233
234    for node, deps in tuple(dag.items()):
235        for dep in deps:
236            if dep not in dag:
237                dag[dep] = set()
238
239    while dag:
240        current = {node for node, deps in dag.items() if not deps}
241
242        if not current:
243            raise ValueError("Cycle error")
244
245        for node in current:
246            dag.pop(node)
247
248        for deps in dag.values():
249            deps -= current
250
251        result.extend(sorted(current))  # type: ignore
252
253    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.
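A quick sketch; note that the input dict is consumed (nodes are popped) while sorting:

>>> from sqlglot.helper import tsort
>>> tsort({"c": {"a", "b"}, "b": {"a"}, "a": set()})
['a', 'b', 'c']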

def open_file(file_name: str) -> TextIO:
256def open_file(file_name: str) -> t.TextIO:
257    """Open a file that may be compressed as gzip and return it in universal newline mode."""
258    with open(file_name, "rb") as f:
259        gzipped = f.read(2) == b"\x1f\x8b"
260
261    if gzipped:
262        import gzip
263
264        return gzip.open(file_name, "rt", newline="")
265
266    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
269@contextmanager
270def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
271    """
272    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
273
274    Args:
275        read_csv: A `ReadCSV` function call.
276
277    Yields:
278        A Python csv reader.
279    """
280    args = read_csv.expressions
281    file = open_file(read_csv.name)
282
283    delimiter = ","
284    args = iter(arg.name for arg in args)  # type: ignore
285    for k, v in zip(args, args):
286        if k == "delimiter":
287            delimiter = v
288
289    try:
290        import csv as csv_
291
292        yield csv_.reader(file, delimiter=delimiter)
293    finally:
294        file.close()

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A Python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:
297def find_new_name(taken: t.Collection[str], base: str) -> str:
298    """
299    Searches for a new name.
300
301    Args:
302        taken: A collection of taken names.
303        base: Base name to alter.
304
305    Returns:
306        The new, available name.
307    """
308    if base not in taken:
309        return base
310
311    i = 2
312    new = f"{base}_{i}"
313    while new in taken:
314        i += 1
315        new = f"{base}_{i}"
316
317    return new

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.
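A quick sketch of the collision handling:

>>> from sqlglot.helper import find_new_name
>>> find_new_name({"a", "a_2"}, "a")
'a_3'
>>> find_new_name({"b"}, "a")
'a'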

def is_int(text: str) -> bool:
320def is_int(text: str) -> bool:
321    try:
322        int(text)
323        return True
324    except ValueError:
325        return False
def name_sequence(prefix: str) -> Callable[[], str]:
328def name_sequence(prefix: str) -> t.Callable[[], str]:
329    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
330    sequence = count()
331    return lambda: f"{prefix}{next(sequence)}"

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").
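A quick sketch:

>>> from sqlglot.helper import name_sequence
>>> next_name = name_sequence("t")
>>> next_name(), next_name(), next_name()
('t0', 't1', 't2')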

def object_to_dict(obj: Any, **kwargs) -> Dict:
334def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
335    """Returns a dictionary created from an object's attributes."""
336    return {
337        **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
338        **kwargs,
339    }

Returns a dictionary created from an object's attributes.
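A quick sketch (the `Point` class is hypothetical); keyword arguments are merged into the result:

>>> from sqlglot.helper import object_to_dict
>>> class Point:
...     def __init__(self, x, y):
...         self.x = x
...         self.y = y
>>> object_to_dict(Point(1, 2), z=3)
{'x': 1, 'y': 2, 'z': 3}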

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
342def split_num_words(
343    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
344) -> t.List[t.Optional[str]]:
345    """
346    Perform a split on a value and return N words as a result with `None` used for words that don't exist.
347
348    Args:
349        value: The value to be split.
350        sep: The value to use to split on.
351        min_num_words: The minimum number of words that are going to be in the result.
352        fill_from_start: Indicates whether `None` values should be inserted at the start or the end of the list.
353
354    Examples:
355        >>> split_num_words("db.table", ".", 3)
356        [None, 'db', 'table']
357        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
358        ['db', 'table', None]
359        >>> split_num_words("db.table", ".", 1)
360        ['db', 'table']
361
362    Returns:
363        The list of words returned by `split`, possibly augmented by a number of `None` values.
364    """
365    words = value.split(sep)
366    if fill_from_start:
367        return [None] * (min_num_words - len(words)) + words
368    return words + [None] * (min_num_words - len(words))

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Indicates whether None values should be inserted at the start or the end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
371def is_iterable(value: t.Any) -> bool:
372    """
373    Checks if the value is an iterable, excluding the types `str` and `bytes`.
374
375    Examples:
376        >>> is_iterable([1,2])
377        True
378        >>> is_iterable("test")
379        False
380
381    Args:
382        value: The value to check if it is an iterable.
383
384    Returns:
385        A `bool` value indicating if it is an iterable.
386    """
387    from sqlglot import Expression
388
389    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))

Checks if the value is an iterable, excluding the types str and bytes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
392def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
393    """
394    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
395    type `str` and `bytes` are not regarded as iterables.
396
397    Examples:
398        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
399        [1, 2, 3, 4, 5, 'bla']
400        >>> list(flatten([1, 2, 3]))
401        [1, 2, 3]
402
403    Args:
404        values: The value to be flattened.
405
406    Yields:
407        Non-iterable elements in `values`.
408    """
409    for value in values:
410        if is_iterable(value):
411            yield from flatten(value)
412        else:
413            yield value

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
416def dict_depth(d: t.Dict) -> int:
417    """
418    Get the nesting depth of a dictionary.
419
420    Example:
421        >>> dict_depth(None)
422        0
423        >>> dict_depth({})
424        1
425        >>> dict_depth({"a": "b"})
426        1
427        >>> dict_depth({"a": {}})
428        2
429        >>> dict_depth({"a": {"b": {}}})
430        3
431    """
432    try:
433        return 1 + dict_depth(next(iter(d.values())))
434    except AttributeError:
435        # d doesn't have attribute "values"
436        return 0
437    except StopIteration:
438        # d.values() returns an empty sequence
439        return 1

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
442def first(it: t.Iterable[T]) -> T:
443    """Returns the first element from an iterable (useful for sets)."""
444    return next(i for i in it)

Returns the first element from an iterable (useful for sets).
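A quick sketch with a single-element set:

>>> from sqlglot.helper import first
>>> first({42})
42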

def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:
447def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
448    """
449    Merges a sequence of ranges, represented as tuples (low, high) whose values
450    belong to some totally-ordered set.
451
452    Example:
453        >>> merge_ranges([(1, 3), (2, 6)])
454        [(1, 6)]
455    """
456    if not ranges:
457        return []
458
459    ranges = sorted(ranges)
460
461    merged = [ranges[0]]
462
463    for start, end in ranges[1:]:
464        last_start, last_end = merged[-1]
465
466        if start <= last_end:
467            merged[-1] = (last_start, max(last_end, end))
468        else:
469            merged.append((start, end))
470
471    return merged

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
def is_iso_date(text: str) -> bool:
474def is_iso_date(text: str) -> bool:
475    try:
476        datetime.date.fromisoformat(text)
477        return True
478    except ValueError:
479        return False
def is_iso_datetime(text: str) -> bool:
482def is_iso_datetime(text: str) -> bool:
483    try:
484        datetime.datetime.fromisoformat(text)
485        return True
486    except ValueError:
487        return False
DATE_UNITS = {'month', 'year', 'week', 'year_month', 'quarter', 'day'}
def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:
494def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
495    return expression is not None and expression.name.lower() in DATE_UNITS
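A sketch using `exp.var` from `sqlglot.expressions` to build the unit expression; matching is case-insensitive:

>>> from sqlglot import exp
>>> from sqlglot.helper import is_date_unit
>>> is_date_unit(exp.var("MONTH"))
True
>>> is_date_unit(exp.var("hour"))
False
>>> is_date_unit(None)
False
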
class SingleValuedMapping(typing.Mapping[~K, ~V]):
502class SingleValuedMapping(t.Mapping[K, V]):
503    """
504    Mapping where all keys return the same value.
505
506    This rigamarole is meant to avoid copying keys, which was originally intended
507    as an optimization while qualifying columns for tables with lots of columns.
508    """
509
510    def __init__(self, keys: t.Collection[K], value: V):
511        self._keys = keys if isinstance(keys, Set) else set(keys)
512        self._value = value
513
514    def __getitem__(self, key: K) -> V:
515        if key in self._keys:
516            return self._value
517        raise KeyError(key)
518
519    def __len__(self) -> int:
520        return len(self._keys)
521
522    def __iter__(self) -> t.Iterator[K]:
523        return iter(self._keys)

Mapping where all keys return the same value.

This rigamarole is meant to avoid copying keys, which was originally intended as an optimization while qualifying columns for tables with lots of columns.

SingleValuedMapping(keys: Collection[~K], value: ~V)
510    def __init__(self, keys: t.Collection[K], value: V):
511        self._keys = keys if isinstance(keys, Set) else set(keys)
512        self._value = value
Inherited Members
collections.abc.Mapping
get
keys
items
values
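A quick sketch; every key maps to the one shared value:

>>> from sqlglot.helper import SingleValuedMapping
>>> columns = SingleValuedMapping(["id", "name", "email"], "users")
>>> columns["name"]
'users'
>>> len(columns)
3
>>> "missing" in columns
False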