Edit on GitHub

sqlglot.helper

  1from __future__ import annotations
  2
  3import datetime
  4import inspect
  5import logging
  6import re
  7import sys
  8import typing as t
  9from collections.abc import Collection
 10from contextlib import contextmanager
 11from copy import copy
 12from enum import Enum
 13from itertools import count
 14
 15if t.TYPE_CHECKING:
 16    from sqlglot import exp
 17    from sqlglot._typing import A, E, T
 18    from sqlglot.expressions import Expression
 19
 20
 21CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
 22PYTHON_VERSION = sys.version_info[:2]
 23logger = logging.getLogger("sqlglot")
 24
 25
 26class AutoName(Enum):
 27    """
 28    This is used for creating Enum classes where `auto()` is the string form
 29    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
 30
 31    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
 32    """
 33
 34    def _generate_next_value_(name, _start, _count, _last_values):
 35        return name
 36
 37
 38class classproperty(property):
 39    """
 40    Similar to a normal property but works for class methods
 41    """
 42
 43    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
 44        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
 45
 46
 47def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 48    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 49    try:
 50        return seq[index]
 51    except IndexError:
 52        return None
 53
 54
 55@t.overload
 56def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 57
 58
 59@t.overload
 60def ensure_list(value: T) -> t.List[T]: ...
 61
 62
 63def ensure_list(value):
 64    """
 65    Ensures that a value is a list, otherwise casts or wraps it into one.
 66
 67    Args:
 68        value: The value of interest.
 69
 70    Returns:
 71        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 72    """
 73    if value is None:
 74        return []
 75    if isinstance(value, (list, tuple)):
 76        return list(value)
 77
 78    return [value]
 79
 80
 81@t.overload
 82def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
 83
 84
 85@t.overload
 86def ensure_collection(value: T) -> t.Collection[T]: ...
 87
 88
 89def ensure_collection(value):
 90    """
 91    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 92
 93    Args:
 94        value: The value of interest.
 95
 96    Returns:
 97        The value if it's a collection, or else the value wrapped in a list.
 98    """
 99    if value is None:
100        return []
101    return (
102        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
103    )
104
105
def csv(*args: str, sep: str = ", ") -> str:
    """
    Join the truthy string arguments with a separator.

    Falsy arguments (such as empty strings) are skipped entirely.

    Args:
        args: The string arguments to format.
        sep: The argument separator.

    Returns:
        The non-empty arguments joined by `sep`.
    """
    non_empty = [arg for arg in args if arg]
    return sep.join(non_empty)
118
119
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Only classes that are attributes of the module named `module_name` are considered, and
    each class in `classes` counts as a subclass of itself. Results are ordered by attribute
    name, as returned by `inspect.getmembers`.

    Args:
        module_name: The name of an already-imported module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list, given either as a single
            class or as a tuple of classes.

    Returns:
        The target subclasses.

    Raises:
        KeyError: If `module_name` is not present in `sys.modules`.
    """
    # `obj not in exclude` raises TypeError when `exclude` is a bare class (classes are not
    # containers), even though the signature allows one, so wrap it in a tuple.
    if inspect.isclass(exclude):
        exclude = (exclude,)

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
        )
    ]
143
144
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Applies an offset to a given integer index expression.

    The offset is only applied when all of the following hold; otherwise `expressions` is
    returned unaffected:

    - `offset` is non-zero and `expressions` contains exactly one expression;
    - the type of `this` is ARRAY or UNKNOWN (`annotate_types` is run on it if untyped);
    - the type of the index expression is an integer type (annotated the same way).

    When applied, the index becomes `simplify(index + offset)` and a warning is logged.

    Args:
        this: The target of the index.
        expressions: The index expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.

    Returns:
        A one-element list with the offset index expression, or `expressions` itself when
        the offset is not applied.
    """
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    # Deferred imports (presumably to avoid circular imports at module load time).
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    if not this.type:
        annotate_types(this)

    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression)
    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.warning("Applying array index offset (%s)", offset)
        expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
        return [expression]

    return expressions
188
189
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from CamelCase to upper-case SNAKE_CASE and returns the result.

    An underscore is inserted before every uppercase letter except a leading one, and the
    result is upper-cased, e.g. "ArrayAgg" -> "ARRAY_AGG".
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()
193
194
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    A fix point is detected by comparing `hash(expression)` before and after one
    application of `func`; the loop stops as soon as the two hashes are equal.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied. Its return value is used as the
            expression for the next iteration.

    Returns:
        The transformed expression.
    """
    while True:
        # Cache every node's hash in `_hash` before transforming. Nodes are visited in
        # reverse walk order, presumably so descendants are cached before their ancestors
        # and an ancestor's hash can reuse them (depends on Expression.__hash__; verify).
        for n, *_ in reversed(tuple(expression.walk())):
            n._hash = hash(n)

        start = hash(expression)
        expression = func(expression)

        # Drop the cached hashes so the comparison below reflects the tree returned by
        # `func` rather than stale cached values (assumes __hash__ consults `_hash`).
        for n, *_ in expression.walk():
            n._hash = None
        if start == hash(expression):
            break

    return expression
219
220
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Dependencies always appear before their dependents. Nodes that become ready in the same
    round are emitted in sorted order, so node values must be mutually orderable. The input
    mapping and its dependency sets are left unmodified.

    Args:
        dag: The graph to be sorted, mapping each node to the set of nodes it depends on.
            Dependencies that are not keys are treated as nodes without dependencies.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Work on a private copy: the loop below removes nodes and shrinks dependency sets,
    # which previously mutated the caller's graph.
    graph = {node: set(deps) for node, deps in dag.items()}

    # Make sure every dependency is itself a node of the graph.
    for deps in tuple(graph.values()):
        for dep in deps:
            graph.setdefault(dep, set())

    result = []

    while graph:
        ready = {node for node, deps in graph.items() if not deps}

        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del graph[node]

        for deps in graph.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result
253
254
def open_file(file_name: str) -> t.TextIO:
    """
    Open a text file that may be gzip-compressed.

    Compression is detected from the gzip magic bytes (1f 8b), not from the file name.
    Both compressed and plain files are decoded as UTF-8 and opened with `newline=""`,
    so newlines are recognised but line endings are returned untranslated.

    Args:
        file_name: Path of the file to open.

    Returns:
        A text-mode file object; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Match the plain-file branch; without an explicit encoding gzip.open falls back
        # to the locale's preferred encoding.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")
266
267
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The extra arguments are read as alternating key/value names. Only the "delimiter" key is
    recognised (default ","; the last occurrence wins), and an unpaired trailing key is
    ignored. The file is opened with `open_file` and closed when the context exits, even if
    an error occurs.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    import csv as csv_

    # Resolve the options before opening the file, so a failure while reading the argument
    # names cannot leak an open file handle.
    delimiter = ","
    names = iter(arg.name for arg in read_csv.expressions)
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    file = open_file(read_csv.name)
    try:
        yield csv_.reader(file, delimiter=delimiter)
    finally:
        file.close()
294
295
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Pick a name based on `base` that is not in `taken`.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        `base` itself if it is free, otherwise the first free `base_N` for N = 2, 3, ...
    """
    if base not in taken:
        return base

    candidates = (f"{base}_{suffix}" for suffix in count(2))
    return next(candidate for candidate in candidates if candidate not in taken)
317
318
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Return a callable producing `prefix` followed by 0, 1, 2, ... on successive calls."""
    counter = count()

    def _next_name() -> str:
        return f"{prefix}{next(counter)}"

    return _next_name
323
324
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Build a dict from `vars(obj)`, copying each attribute value.

    Values that have a `copy` attribute are copied with `value.copy()`; all others with
    `copy.copy`. Keyword arguments are applied last and override same-named attributes.
    """
    result = {}
    for key, value in vars(obj).items():
        result[key] = value.copy() if hasattr(value, "copy") else copy(value)
    result.update(kwargs)
    return result
331
332
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Split `value` on `sep` and pad the result with `None` up to `min_num_words` items.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Whether the `None` padding goes at the start (True) or at the
            end (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    # A non-positive repeat count yields [], so nothing is added when there are already
    # at least `min_num_words` words.
    padding = [None] * (min_num_words - len(words))
    return padding + words if fill_from_start else words + padding
360
361
def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding `str`, `bytes` and sqlglot `Expression`s.

    A value counts as iterable when it has an `__iter__` attribute. Strings, bytes and
    `Expression` nodes are treated as scalars, so `flatten` does not descend into them.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    # Local import, presumably to avoid a circular import with the sqlglot package.
    from sqlglot import Expression

    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))
381
382
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Recursively yield the leaf elements of a possibly nested iterable.

    What counts as nested is decided by `is_iterable`: `str`, `bytes` and sqlglot
    `Expression` objects are yielded as-is rather than descended into.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)
405
406
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary by following the first value at each level.

    Anything without a `values` attribute has depth 0, and an empty mapping has depth 1.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        child = next(iter(d.values()))
    except AttributeError:
        # Not a mapping: there is no further level.
        return 0
    except StopIteration:
        # An empty mapping still counts as one level.
        return 1
    return 1 + dict_depth(child)
431
432
def first(it: t.Iterable[T]) -> T:
    """Return the first element produced by `it` (useful for sets); raises StopIteration if empty."""
    return next(iter(it))
436
437
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merge overlapping or touching (low, high) ranges over a totally-ordered set.

    The input list is not modified; the result is sorted by range start.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    ordered = sorted(ranges)
    merged = [ordered[0]]

    for lo, hi in ordered[1:]:
        prev_lo, prev_hi = merged[-1]
        if lo <= prev_hi:
            # Overlaps (or touches) the previous range: extend it.
            merged[-1] = (prev_lo, max(prev_hi, hi))
            continue
        merged.append((lo, hi))

    return merged
463
464
def is_iso_date(text: str) -> bool:
    """
    Return True if `datetime.date.fromisoformat` accepts `text`, else False.

    Only ValueError is treated as "not a date"; the accepted syntax follows the running
    Python version's `fromisoformat`.
    """
    try:
        datetime.date.fromisoformat(text)
    except ValueError:
        return False
    return True
471
472
def is_iso_datetime(text: str) -> bool:
    """
    Return True if `datetime.datetime.fromisoformat` accepts `text`, else False.

    Only ValueError is treated as "not a datetime"; the accepted syntax follows the running
    Python version's `fromisoformat`.
    """
    try:
        datetime.datetime.fromisoformat(text)
    except ValueError:
        return False
    return True
479
480
# Interval units that operate on date components
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """Return True if `expression` is not None and its lower-cased `name` is in DATE_UNITS."""
    if expression is None:
        return False
    return expression.name.lower() in DATE_UNITS
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
class AutoName(Enum):
    """
    Enum base class whose `auto()` values are the members' own identifiers.

    For example, a member declared as `FOO = auto()` gets the value "FOO".

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    def _generate_next_value_(name, _start, _count, _last_values):
        # The Enum machinery passes the member's identifier first; reuse it as the value.
        member_value = name
        return member_value

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values

Inherited Members
enum.Enum
name
value
class classproperty(builtins.property):
class classproperty(property):
    """
    Like `property`, but the getter receives the owning class instead of an instance.

    Reading the attribute on the class or on an instance calls the getter with the class.
    """

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        # Bind the getter to `owner` exactly as @classmethod would, then invoke it.
        bound_getter = classmethod(self.fget).__get__(None, owner)  # type: ignore
        return bound_getter()

Similar to a normal property but works for class methods

Inherited Members
builtins.property
property
getter
setter
deleter
fget
fset
fdel
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """
    Safely index into a sequence.

    Args:
        seq: The sequence to read from.
        index: The position to read; negative values count from the end as usual.

    Returns:
        `seq[index]`, or `None` when `index` is out of range.
    """
    try:
        item = seq[index]
    except IndexError:
        return None
    return item

Returns the value in seq at position index, or None if index is out of bounds.

def ensure_list(value):
def ensure_list(value):
    """
    Coerce a value into a list.

    Args:
        value: The value of interest.

    Returns:
        `[]` for `None`; a new list with the same items for a list or tuple; otherwise a
        one-element list containing `value` (other collections such as sets are wrapped,
        not converted).
    """
    if value is None:
        return []
    return list(value) if isinstance(value, (list, tuple)) else [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.

def ensure_collection(value):
 90def ensure_collection(value):
 91    """
 92    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 93
 94    Args:
 95        value: The value of interest.
 96
 97    Returns:
 98        The value if it's a collection, or else the value wrapped in a list.
 99    """
100    if value is None:
101        return []
102    return (
103        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
104    )

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.

def csv(*args: str, sep: str = ', ') -> str:
def csv(*args: str, sep: str = ", ") -> str:
    """
    Join the truthy string arguments with a separator.

    Falsy arguments (such as empty strings) are skipped entirely.

    Args:
        args: The string arguments to format.
        sep: The argument separator.

    Returns:
        The non-empty arguments joined by `sep`.
    """
    non_empty = [arg for arg in args if arg]
    return sep.join(non_empty)

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Only classes that are attributes of the module named `module_name` are considered, and
    each class in `classes` counts as a subclass of itself. Results are ordered by attribute
    name, as returned by `inspect.getmembers`.

    Args:
        module_name: The name of an already-imported module to search for subclasses in.
        classes: Class(es) we want to find the subclasses of.
        exclude: Class(es) we want to exclude from the returned list, given either as a single
            class or as a tuple of classes.

    Returns:
        The target subclasses.

    Raises:
        KeyError: If `module_name` is not present in `sys.modules`.
    """
    # `obj not in exclude` raises TypeError when `exclude` is a bare class (classes are not
    # containers), even though the signature allows one, so wrap it in a tuple.
    if inspect.isclass(exclude):
        exclude = (exclude,)

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
        )
    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int) -> List[~E]:
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Applies an offset to a given integer index expression.

    The offset is only applied when all of the following hold; otherwise `expressions` is
    returned unaffected:

    - `offset` is non-zero and `expressions` contains exactly one expression;
    - the type of `this` is ARRAY or UNKNOWN (`annotate_types` is run on it if untyped);
    - the type of the index expression is an integer type (annotated the same way).

    When applied, the index becomes `simplify(index + offset)` and a warning is logged.

    Args:
        this: The target of the index.
        expressions: The index expression the offset will be applied to, wrapped in a list.
        offset: The offset that will be applied.

    Returns:
        A one-element list with the offset index expression, or `expressions` itself when
        the offset is not applied.
    """
    if not offset or len(expressions) != 1:
        return expressions

    expression = expressions[0]

    # Deferred imports (presumably to avoid circular imports at module load time).
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    if not this.type:
        annotate_types(this)

    if t.cast(exp.DataType, this.type).this not in (
        exp.DataType.Type.UNKNOWN,
        exp.DataType.Type.ARRAY,
    ):
        return expressions

    if not expression.type:
        annotate_types(expression)
    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.warning("Applying array index offset (%s)", offset)
        expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
        return [expression]

    return expressions

Applies an offset to a given integer literal expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
Returns:

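The original expression with the offset applied to it, wrapped in a list. The expressions argument is returned unaffected if the offset is 0, if it does not contain exactly one expression, if the target's type is neither ARRAY nor UNKNOWN, or if the index expression is not of an integer type.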

def camel_to_snake_case(name: str) -> str:
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from CamelCase to upper-case SNAKE_CASE and returns the result.

    An underscore is inserted before every uppercase letter except a leading one, and the
    result is upper-cased, e.g. "ArrayAgg" -> "ARRAY_AGG".
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from CamelCase to upper-case SNAKE_CASE (e.g. ArrayAgg becomes ARRAY_AGG) and returns the result.

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    A fix point is detected by comparing `hash(expression)` before and after one
    application of `func`; the loop stops as soon as the two hashes are equal.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied. Its return value is used as the
            expression for the next iteration.

    Returns:
        The transformed expression.
    """
    while True:
        # Cache every node's hash in `_hash` before transforming. Nodes are visited in
        # reverse walk order, presumably so descendants are cached before their ancestors
        # and an ancestor's hash can reuse them (depends on Expression.__hash__; verify).
        for n, *_ in reversed(tuple(expression.walk())):
            n._hash = hash(n)

        start = hash(expression)
        expression = func(expression)

        # Drop the cached hashes so the comparison below reflects the tree returned by
        # `func` rather than stale cached values (assumes __hash__ consults `_hash`).
        for n, *_ in expression.walk():
            n._hash = None
        if start == hash(expression):
            break

    return expression

Applies a transformation to a given expression until a fix point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Dependencies always appear before their dependents. Nodes that become ready in the same
    round are emitted in sorted order, so node values must be mutually orderable. The input
    mapping and its dependency sets are left unmodified.

    Args:
        dag: The graph to be sorted, mapping each node to the set of nodes it depends on.
            Dependencies that are not keys are treated as nodes without dependencies.

    Returns:
        A list that contains all of the graph's nodes in topological order.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Work on a private copy: the loop below removes nodes and shrinks dependency sets,
    # which previously mutated the caller's graph.
    graph = {node: set(deps) for node, deps in dag.items()}

    # Make sure every dependency is itself a node of the graph.
    for deps in tuple(graph.values()):
        for dep in deps:
            graph.setdefault(dep, set())

    result = []

    while graph:
        ready = {node for node, deps in graph.items() if not deps}

        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del graph[node]

        for deps in graph.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.

def open_file(file_name: str) -> <class 'TextIO'>:
def open_file(file_name: str) -> t.TextIO:
    """
    Open a text file that may be gzip-compressed.

    Compression is detected from the gzip magic bytes (1f 8b), not from the file name.
    Both compressed and plain files are decoded as UTF-8 and opened with `newline=""`,
    so newlines are recognised but line endings are returned untranslated.

    Args:
        file_name: Path of the file to open.

    Returns:
        A text-mode file object; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Match the plain-file branch; without an explicit encoding gzip.open falls back
        # to the locale's preferred encoding.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The extra arguments are read as alternating key/value names. Only the "delimiter" key is
    recognised (default ","; the last occurrence wins), and an unpaired trailing key is
    ignored. The file is opened with `open_file` and closed when the context exits, even if
    an error occurs.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader.
    """
    import csv as csv_

    # Resolve the options before opening the file, so a failure while reading the argument
    # names cannot leak an open file handle.
    delimiter = ","
    names = iter(arg.name for arg in read_csv.expressions)
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    file = open_file(read_csv.name)
    try:
        yield csv_.reader(file, delimiter=delimiter)
    finally:
        file.close()

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Pick a name based on `base` that is not in `taken`.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        `base` itself if it is free, otherwise the first free `base_N` for N = 2, 3, ...
    """
    if base not in taken:
        return base

    candidates = (f"{base}_{suffix}" for suffix in count(2))
    return next(candidate for candidate in candidates if candidate not in taken)

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.

def name_sequence(prefix: str) -> Callable[[], str]:
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Return a callable producing `prefix` followed by 0, 1, 2, ... on successive calls."""
    counter = count()

    def _next_name() -> str:
        return f"{prefix}{next(counter)}"

    return _next_name

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").

def object_to_dict(obj: Any, **kwargs) -> Dict:
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Build a dict from `vars(obj)`, copying each attribute value.

    Values that have a `copy` attribute are copied with `value.copy()`; all others with
    `copy.copy`. Keyword arguments are applied last and override same-named attributes.
    """
    result = {}
    for key, value in vars(obj).items():
        result[key] = value.copy() if hasattr(value, "copy") else copy(value)
    result.update(kwargs)
    return result

Returns a dictionary created from an object's attributes.

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Split `value` on `sep` and pad the result with `None` up to `min_num_words` items.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Whether the `None` padding goes at the start (True) or at the
            end (False) of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    # A non-positive repeat count yields [], so nothing is added when there are already
    # at least `min_num_words` words.
    padding = [None] * (min_num_words - len(words))
    return padding + words if fill_from_start else words + padding

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Whether the None values should be inserted at the start (True) or at the end (False) of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding `str`, `bytes` and sqlglot `Expression`s.

    A value counts as iterable when it has an `__iter__` attribute. Strings, bytes and
    `Expression` nodes are treated as scalars, so `flatten` does not descend into them.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    # Local import, presumably to avoid a circular import with the sqlglot package.
    from sqlglot import Expression

    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))

Checks if the value is an iterable, excluding the types str and bytes as well as sqlglot Expression nodes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Recursively yield the leaf elements of a possibly nested iterable.

    What counts as nested is decided by `is_iterable`: `str`, `bytes` and sqlglot
    `Expression` objects are yielded as-is rather than descended into.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for item in values:
        if not is_iterable(item):
            yield item
            continue
        yield from flatten(item)

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes, as well as sqlglot Expression nodes, are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
def dict_depth(d: t.Dict) -> int:
    """
    Get the nesting depth of a dictionary by following the first value at each level.

    Anything without a `values` attribute has depth 0, and an empty mapping has depth 1.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        child = next(iter(d.values()))
    except AttributeError:
        # Not a mapping: there is no further level.
        return 0
    except StopIteration:
        # An empty mapping still counts as one level.
        return 1
    return 1 + dict_depth(child)

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
def first(it: t.Iterable[T]) -> T:
    """Return the first element produced by `it` (useful for sets); raises StopIteration if empty."""
    return next(iter(it))

Returns the first element from an iterable (useful for sets).

def merge_ranges(ranges: List[Tuple[~A, ~A]]) -> List[Tuple[~A, ~A]]:
def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """
    Merge overlapping or touching (low, high) ranges over a totally-ordered set.

    The input list is not modified; the result is sorted by range start.

    Example:
        >>> merge_ranges([(1, 3), (2, 6)])
        [(1, 6)]
    """
    if not ranges:
        return []

    ordered = sorted(ranges)
    merged = [ordered[0]]

    for lo, hi in ordered[1:]:
        prev_lo, prev_hi = merged[-1]
        if lo <= prev_hi:
            # Overlaps (or touches) the previous range: extend it.
            merged[-1] = (prev_lo, max(prev_hi, hi))
            continue
        merged.append((lo, hi))

    return merged

Merges a sequence of ranges, represented as tuples (low, high) whose values belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
def is_iso_date(text: str) -> bool:
def is_iso_date(text: str) -> bool:
    """
    Return True if `datetime.date.fromisoformat` accepts `text`, else False.

    Only ValueError is treated as "not a date"; the accepted syntax follows the running
    Python version's `fromisoformat`.
    """
    try:
        datetime.date.fromisoformat(text)
    except ValueError:
        return False
    return True
def is_iso_datetime(text: str) -> bool:
def is_iso_datetime(text: str) -> bool:
    """
    Return True if `datetime.datetime.fromisoformat` accepts `text`, else False.

    Only ValueError is treated as "not a datetime"; the accepted syntax follows the running
    Python version's `fromisoformat`.
    """
    try:
        datetime.datetime.fromisoformat(text)
    except ValueError:
        return False
    return True
DATE_UNITS = {'month', 'quarter', 'year', 'week', 'day', 'year_month'}
def is_date_unit(expression: Optional[sqlglot.expressions.Expression]) -> bool:
def is_date_unit(expression: t.Optional[exp.Expression]) -> bool:
    """Return True if `expression` is not None and its lower-cased `name` is in DATE_UNITS."""
    if expression is None:
        return False
    return expression.name.lower() in DATE_UNITS