Edit on GitHub

sqlglot.helper

  1from __future__ import annotations
  2
  3import inspect
  4import logging
  5import re
  6import sys
  7import typing as t
  8from collections.abc import Collection
  9from contextlib import contextmanager
 10from copy import copy
 11from enum import Enum
 12from itertools import count
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot import exp
 16    from sqlglot._typing import E, T
 17    from sqlglot.expressions import Expression
 18
# Zero-width match just before every uppercase letter except one at the very start of
# the string; `camel_to_snake_case` substitutes "_" at these positions.
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
# (major, minor) version tuple of the running interpreter.
PYTHON_VERSION = sys.version_info[:2]
# Logger shared by the package under the "sqlglot" name.
logger = logging.getLogger("sqlglot")
 22
 23
class AutoName(Enum):
    """
    This is used for creating Enum classes where `auto()` is the string form
    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    # Hook used by `auto()`: the signature (name, start, count, last_values) with no `self`
    # follows the Enum how-to referenced above. Returning the member name unchanged makes
    # each member's value equal to its identifier; the other arguments are ignored.
    def _generate_next_value_(name, _start, _count, _last_values):
        return name
 34
 35
 36def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 37    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
 38    try:
 39        return seq[index]
 40    except IndexError:
 41        return None
 42
 43
 44@t.overload
 45def ensure_list(value: t.Collection[T]) -> t.List[T]:
 46    ...
 47
 48
 49@t.overload
 50def ensure_list(value: T) -> t.List[T]:
 51    ...
 52
 53
 54def ensure_list(value):
 55    """
 56    Ensures that a value is a list, otherwise casts or wraps it into one.
 57
 58    Args:
 59        value: The value of interest.
 60
 61    Returns:
 62        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
 63    """
 64    if value is None:
 65        return []
 66    if isinstance(value, (list, tuple)):
 67        return list(value)
 68
 69    return [value]
 70
 71
 72@t.overload
 73def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
 74    ...
 75
 76
 77@t.overload
 78def ensure_collection(value: T) -> t.Collection[T]:
 79    ...
 80
 81
 82def ensure_collection(value):
 83    """
 84    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
 85
 86    Args:
 87        value: The value of interest.
 88
 89    Returns:
 90        The value if it's a collection, or else the value wrapped in a list.
 91    """
 92    if value is None:
 93        return []
 94    return (
 95        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
 96    )
 97
 98
 99def csv(*args: str, sep: str = ", ") -> str:
100    """
101    Formats any number of string arguments as CSV.
102
103    Args:
104        args: The string arguments to format.
105        sep: The argument separator.
106
107    Returns:
108        The arguments formatted as a CSV string.
109    """
110    return sep.join(arg for arg in args if arg)
111
112
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of an already-imported module (a key of `sys.modules`) to search.
        classes: Class(es) we want to find the subclasses of; a class counts as its own subclass.
        exclude: Class(es) we want to exclude from the returned list, as a single class or a tuple.

    Returns:
        The matching classes found among the module's attributes, ordered by attribute name.
    """
    # A bare class is not a container, so `obj not in exclude` would raise TypeError for it;
    # normalize a single class to a 1-tuple, matching how `classes` accepts either form.
    excluded = (exclude,) if isinstance(exclude, type) else exclude

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded,
        )
    ]
136
137
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Shifts a single integer-typed index expression by `offset`.

    Args:
        this: The target of the index (the expression being subscripted).
        expressions: The index expression(s); only a one-element list is ever rewritten.
        offset: The offset that will be applied.

    Returns:
        A one-element list holding the simplified `index + offset` when `offset` is non-zero,
        exactly one index is given, `this` is typed ARRAY or UNKNOWN, and the index is
        integer-typed. In every other case `expressions` itself is returned unchanged.
    """
    if not offset or len(expressions) != 1:
        return expressions

    index = expressions[0]

    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    # Only subscripts on arrays (or on targets whose type is unknown) are shifted.
    # annotate_types is called only when no type is present yet.
    if not this.type:
        annotate_types(this)
    target_type = t.cast(exp.DataType, this.type).this
    if target_type not in (exp.DataType.Type.UNKNOWN, exp.DataType.Type.ARRAY):
        return expressions

    if not index.type:
        annotate_types(index)
    if t.cast(exp.DataType, index.type).this not in exp.DataType.INTEGER_TYPES:
        return expressions

    logger.warning("Applying array index offset (%s)", offset)
    # Build `index + offset` on a copy so the caller's node is left untouched.
    shifted = exp.Add(this=index.copy(), expression=exp.Literal.number(offset))
    return [simplify(shifted)]
183
184
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from camelCase/PascalCase to upper-case snake_case and returns the result.

    An underscore is inserted before every uppercase letter that is not the first character,
    then the whole string is upper-cased, e.g. "ArrayAgg" -> "ARRAY_AGG". Runs of capitals are
    split letter by letter, e.g. "JSONPath" -> "J_S_O_N_PATH".
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()
188
189
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Repeatedly applies `func` to `expression` until a pass leaves it unchanged.

    Before each pass every node's hash is cached in `_hash`. After the pass the cache is
    cleared on the resulting tree, and the root's hash is compared with the value taken
    before the pass. Equal hashes are treated as the fix point.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied; its result feeds the next pass.

    Returns:
        The expression produced by the final pass.
    """
    while True:
        # Cache hashes in reverse walk order, presumably so that children (visited after
        # their parents by `walk`) are hashed before the parents that depend on them.
        nodes = tuple(expression.walk())
        for node, *_ in reversed(nodes):
            node._hash = hash(node)

        before = hash(expression)
        expression = func(expression)

        # Clear the cached hashes, presumably forcing `hash` to recompute from the new tree.
        for node, *_ in expression.walk():
            node._hash = None

        if hash(expression) == before:
            return expression
214
215
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Args:
        dag: The graph to be sorted, mapping each node to the set of nodes it depends on.
            Nodes that appear only as dependencies are treated as having no dependencies.
            The mapping and its sets are not modified.

    Returns:
        A list that contains all of the graph's nodes in topological order. Every node comes
        after all of its dependencies, and nodes that become ready in the same round are sorted.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Work on a private copy: the loop below consumes the graph, and callers should not
    # find their mapping emptied or their dependency sets drained after the call.
    pending = {node: set(deps) for node, deps in dag.items()}
    for deps in dag.values():
        for dep in deps:
            pending.setdefault(dep, set())

    result: t.List[T] = []
    while pending:
        # Every node whose dependencies have all been emitted is ready in this round.
        ready = {node for node, deps in pending.items() if not deps}
        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del pending[node]
        for deps in pending.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result
248
249
def open_file(file_name: str) -> t.TextIO:
    """
    Opens a file that may be gzip-compressed and returns it as a UTF-8 text stream.

    Compression is detected from the two-byte gzip magic number, not from the file name.
    Both branches use `newline=""`: universal newlines are recognized on input, but line
    endings are returned untranslated, which is what the `csv` module expects.

    Args:
        file_name: Path of the file to open.

    Returns:
        An open text stream; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Decode as UTF-8 like the plain-file branch; without an explicit encoding, gzip's
        # text mode falls back to the locale's preferred encoding.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")
261
262
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The call's extra arguments are read as flat key/value pairs. Only `delimiter` is
    recognized (default ","); other keys are ignored, and a trailing key without a value
    is dropped.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader over the file; the file is closed when the context exits.
    """
    import csv as csv_

    # Resolve options before opening the file, so an error while reading them cannot
    # leave an open file handle behind.
    delimiter = ","
    names = iter(arg.name for arg in read_csv.expressions)
    # zip over the same iterator pairs consecutive items: (k0, v0), (k1, v1), ...
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    with open_file(read_csv.name) as handle:
        yield csv_.reader(handle, delimiter=delimiter)
289
290
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Picks a name that does not collide with any in `taken`.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        `base` itself if it is free; otherwise the first of `base_2`, `base_3`, ...
        that is not in `taken`.
    """
    if base not in taken:
        return base

    candidates = (f"{base}_{i}" for i in count(2))
    return next(name for name in candidates if name not in taken)
312
313
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Builds a zero-argument callable that returns `prefix` followed by 0, 1, 2, ... on successive calls.

    Each call to `name_sequence` starts its own independent counter.
    """
    counter = count()

    def next_name() -> str:
        return f"{prefix}{next(counter)}"

    return next_name
318
319
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Builds a dictionary from `obj`'s instance attributes (as returned by `vars(obj)`).

    Each attribute value is copied: through its own `copy()` method when it has one,
    otherwise with `copy.copy`. Keyword arguments are merged in last, uncopied, and
    override attributes of the same name.
    """
    result = {}
    for key, value in vars(obj).items():
        result[key] = value.copy() if hasattr(value, "copy") else copy(value)
    result.update(kwargs)
    return result
326
327
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Splits `value` on `sep` and pads the result with `None` up to `min_num_words` entries.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Whether the `None` padding goes at the start (True) or the end (False).

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The words from `split`, padded with `None` when there are fewer than `min_num_words`;
        never truncated.
    """
    words = value.split(sep)
    # A negative repeat count yields an empty list, so no padding is added in that case.
    padding: t.List[t.Optional[str]] = [None] * (min_num_words - len(words))
    return padding + words if fill_from_start else words + padding
355
356
def is_iterable(value: t.Any) -> bool:
    """
    Tells whether `value` should be treated as an iterable, with `str` and `bytes` treated as scalars.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        `True` if `value` has an `__iter__` attribute and is not `str`/`bytes`, else `False`.
    """
    if isinstance(value, (str, bytes)):
        return False
    return hasattr(value, "__iter__")
374
375
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Lazily yields the leaves of an arbitrarily nested iterable, depth-first.

    Nesting is decided by `is_iterable`, so `str` and `bytes` are yielded whole.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`, in iteration order.
    """
    for value in values:
        if not is_iterable(value):
            yield value
            continue
        yield from flatten(value)
398
399
def dict_depth(d: t.Dict) -> int:
    """
    Measures how deeply dictionaries are nested, following only the first value at each level.

    Anything without a `values` attribute has depth 0; an empty mapping has depth 1.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # Not mapping-like: no `values` attribute.
        return 0
    except StopIteration:
        # A mapping with no values.
        return 1
    return 1 + dict_depth(first_value)
424
425
def first(it: t.Iterable[T]) -> T:
    """
    Returns the first element produced by iterating `it` (handy for sets, which cannot be indexed).

    Raises:
        StopIteration: If `it` yields nothing.
    """
    return next(iter(it))
CAMEL_CASE_PATTERN = re.compile('(?<!^)(?=[A-Z])')
PYTHON_VERSION = (3, 10)
logger = <Logger sqlglot (WARNING)>
class AutoName(enum.Enum):
class AutoName(Enum):
    """
    This is used for creating Enum classes where `auto()` is the string form
    of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

    Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
    """

    # Hook used by `auto()`: the signature (name, start, count, last_values) with no `self`
    # follows the Enum how-to referenced above. Returning the member name unchanged makes
    # each member's value equal to its identifier; the other arguments are ignored.
    def _generate_next_value_(name, _start, _count, _last_values):
        return name

This is used for creating Enum classes where auto() is the string form of the corresponding enum's identifier (e.g. FOO.value results in "FOO").

Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values

Inherited Members
enum.Enum
name
value
def seq_get(seq: Sequence[~T], index: int) -> Optional[~T]:
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """
    Looks up `seq[index]` without raising on an out-of-range position.

    Args:
        seq: The sequence to index into.
        index: The position to read; negative values count from the end, as with normal indexing.

    Returns:
        The element at `index`, or `None` when `index` falls outside the sequence.
    """
    try:
        item = seq[index]
    except IndexError:
        item = None
    return item

Returns the value in seq at position index, or None if index is out of bounds.

def ensure_list(value):
def ensure_list(value):
    """
    Coerces `value` into a list.

    Args:
        value: The value of interest.

    Returns:
        An empty list for `None`; a new list holding the same items for a list or a tuple;
        otherwise a one-element list containing `value` itself.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return [] if value is None else [value]

Ensures that a value is a list, otherwise casts or wraps it into one.

Arguments:
  • value: The value of interest.
Returns:

The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.

def ensure_collection(value):
def ensure_collection(value):
    """
    Guarantees a collection, treating `str` and `bytes` as scalars.

    Args:
        value: The value of interest.

    Returns:
        An empty list for `None`; `value` itself (not a copy) when it is a non-text
        `Collection`; otherwise a one-element list wrapping `value`.
    """
    if value is None:
        return []
    is_text = isinstance(value, (str, bytes))
    if isinstance(value, Collection) and not is_text:
        return value
    return [value]

Ensures that a value is a collection (excluding str and bytes), otherwise wraps it into a list.

Arguments:
  • value: The value of interest.
Returns:

The value if it's a collection, or else the value wrapped in a list.

def csv(*args: str, sep: str = ', ') -> str:
def csv(*args: str, sep: str = ", ") -> str:
    """
    Joins the given strings with `sep`, skipping empty (falsy) ones.

    Args:
        args: The string arguments to format.
        sep: The argument separator.

    Returns:
        The non-empty arguments joined by `sep`; an empty string when none remain.
    """
    # filter(None, ...) keeps only truthy arguments, so "" and None are dropped.
    return sep.join(filter(None, args))

Formats any number of string arguments as CSV.

Arguments:
  • args: The string arguments to format.
  • sep: The argument separator.
Returns:

The arguments formatted as a CSV string.

def subclasses( module_name: str, classes: Union[Type, Tuple[Type, ...]], exclude: Union[Type, Tuple[Type, ...]] = ()) -> List[Type]:
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: The name of an already-imported module (a key of `sys.modules`) to search.
        classes: Class(es) we want to find the subclasses of; a class counts as its own subclass.
        exclude: Class(es) we want to exclude from the returned list, as a single class or a tuple.

    Returns:
        The matching classes found among the module's attributes, ordered by attribute name.
    """
    # A bare class is not a container, so `obj not in exclude` would raise TypeError for it;
    # normalize a single class to a 1-tuple, matching how `classes` accepts either form.
    excluded = (exclude,) if isinstance(exclude, type) else exclude

    return [
        obj
        for _, obj in inspect.getmembers(
            sys.modules[module_name],
            lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in excluded,
        )
    ]

Returns all subclasses for a collection of classes, possibly excluding some of them.

Arguments:
  • module_name: The name of the module to search for subclasses in.
  • classes: Class(es) we want to find the subclasses of.
  • exclude: Class(es) we want to exclude from the returned list.
Returns:

The target subclasses.

def apply_index_offset( this: sqlglot.expressions.Expression, expressions: List[~E], offset: int) -> List[~E]:
def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[E],
    offset: int,
) -> t.List[E]:
    """
    Shifts a single integer-typed index expression by `offset`.

    Args:
        this: The target of the index (the expression being subscripted).
        expressions: The index expression(s); only a one-element list is ever rewritten.
        offset: The offset that will be applied.

    Returns:
        A one-element list holding the simplified `index + offset` when `offset` is non-zero,
        exactly one index is given, `this` is typed ARRAY or UNKNOWN, and the index is
        integer-typed. In every other case `expressions` itself is returned unchanged.
    """
    if not offset or len(expressions) != 1:
        return expressions

    index = expressions[0]

    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.simplify import simplify

    # Only subscripts on arrays (or on targets whose type is unknown) are shifted.
    # annotate_types is called only when no type is present yet.
    if not this.type:
        annotate_types(this)
    target_type = t.cast(exp.DataType, this.type).this
    if target_type not in (exp.DataType.Type.UNKNOWN, exp.DataType.Type.ARRAY):
        return expressions

    if not index.type:
        annotate_types(index)
    if t.cast(exp.DataType, index.type).this not in exp.DataType.INTEGER_TYPES:
        return expressions

    logger.warning("Applying array index offset (%s)", offset)
    # Build `index + offset` on a copy so the caller's node is left untouched.
    shifted = exp.Add(this=index.copy(), expression=exp.Literal.number(offset))
    return [simplify(shifted)]

Applies an offset to a given integer-typed index expression.

Arguments:
  • this: The target of the index.
  • expressions: The expression the offset will be applied to, wrapped in a list.
  • offset: The offset that will be applied.
Returns:

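The original expression with the offset applied to it, wrapped in a list. If the provided expressions argument contains more than one expression, it's returned unaffected; it is also returned unaffected when the offset is 0, when the type of this is neither ARRAY nor UNKNOWN, or when the index expression is not integer-typed.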

def camel_to_snake_case(name: str) -> str:
def camel_to_snake_case(name: str) -> str:
    """
    Converts `name` from camelCase/PascalCase to upper-case snake_case and returns the result.

    An underscore is inserted before every uppercase letter that is not the first character,
    then the whole string is upper-cased, e.g. "ArrayAgg" -> "ARRAY_AGG". Runs of capitals are
    split letter by letter, e.g. "JSONPath" -> "J_S_O_N_PATH".
    """
    return CAMEL_CASE_PATTERN.sub("_", name).upper()

Converts name from camelCase to upper-case snake_case (e.g. ArrayAgg becomes ARRAY_AGG) and returns the result.

def while_changing( expression: sqlglot.expressions.Expression, func: Callable[[sqlglot.expressions.Expression], ~E]) -> ~E:
def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    """
    Repeatedly applies `func` to `expression` until a pass leaves it unchanged.

    Before each pass every node's hash is cached in `_hash`. After the pass the cache is
    cleared on the resulting tree, and the root's hash is compared with the value taken
    before the pass. Equal hashes are treated as the fix point.

    Args:
        expression: The expression to be transformed.
        func: The transformation to be applied; its result feeds the next pass.

    Returns:
        The expression produced by the final pass.
    """
    while True:
        # Cache hashes in reverse walk order, presumably so that children (visited after
        # their parents by `walk`) are hashed before the parents that depend on them.
        nodes = tuple(expression.walk())
        for node, *_ in reversed(nodes):
            node._hash = hash(node)

        before = hash(expression)
        expression = func(expression)

        # Clear the cached hashes, presumably forcing `hash` to recompute from the new tree.
        for node, *_ in expression.walk():
            node._hash = None

        if hash(expression) == before:
            return expression

Applies a transformation to a given expression until a fix point is reached.

Arguments:
  • expression: The expression to be transformed.
  • func: The transformation to be applied.
Returns:

The transformed expression.

def tsort(dag: Dict[~T, Set[~T]]) -> List[~T]:
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Args:
        dag: The graph to be sorted, mapping each node to the set of nodes it depends on.
            Nodes that appear only as dependencies are treated as having no dependencies.
            The mapping and its sets are not modified.

    Returns:
        A list that contains all of the graph's nodes in topological order. Every node comes
        after all of its dependencies, and nodes that become ready in the same round are sorted.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    # Work on a private copy: the loop below consumes the graph, and callers should not
    # find their mapping emptied or their dependency sets drained after the call.
    pending = {node: set(deps) for node, deps in dag.items()}
    for deps in dag.values():
        for dep in deps:
            pending.setdefault(dep, set())

    result: t.List[T] = []
    while pending:
        # Every node whose dependencies have all been emitted is ready in this round.
        ready = {node for node, deps in pending.items() if not deps}
        if not ready:
            raise ValueError("Cycle error")

        for node in ready:
            del pending[node]
        for deps in pending.values():
            deps -= ready

        result.extend(sorted(ready))  # type: ignore

    return result

Sorts a given directed acyclic graph in topological order.

Arguments:
  • dag: The graph to be sorted.
Returns:

A list that contains all of the graph's nodes in topological order.

def open_file(file_name: str) -> <class 'TextIO'>:
def open_file(file_name: str) -> t.TextIO:
    """
    Opens a file that may be gzip-compressed and returns it as a UTF-8 text stream.

    Compression is detected from the two-byte gzip magic number, not from the file name.
    Both branches use `newline=""`: universal newlines are recognized on input, but line
    endings are returned untranslated, which is what the `csv` module expects.

    Args:
        file_name: Path of the file to open.

    Returns:
        An open text stream; the caller is responsible for closing it.
    """
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

    if gzipped:
        import gzip

        # Decode as UTF-8 like the plain-file branch; without an explicit encoding, gzip's
        # text mode falls back to the locale's preferred encoding.
        return gzip.open(file_name, "rt", encoding="utf-8", newline="")

    return open(file_name, encoding="utf-8", newline="")

Open a file that may be compressed as gzip and return it in universal newline mode.

@contextmanager
def csv_reader(read_csv: sqlglot.expressions.ReadCSV) -> Any:
@contextmanager
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    The call's extra arguments are read as flat key/value pairs. Only `delimiter` is
    recognized (default ","); other keys are ignored, and a trailing key without a value
    is dropped.

    Args:
        read_csv: A `ReadCSV` function call.

    Yields:
        A python csv reader over the file; the file is closed when the context exits.
    """
    import csv as csv_

    # Resolve options before opening the file, so an error while reading them cannot
    # leave an open file handle behind.
    delimiter = ","
    names = iter(arg.name for arg in read_csv.expressions)
    # zip over the same iterator pairs consecutive items: (k0, v0), (k1, v1), ...
    for key, value in zip(names, names):
        if key == "delimiter":
            delimiter = value

    with open_file(read_csv.name) as handle:
        yield csv_.reader(handle, delimiter=delimiter)

Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]).

Arguments:
  • read_csv: A ReadCSV function call.
Yields:

A python csv reader.

def find_new_name(taken: Collection[str], base: str) -> str:
def find_new_name(taken: t.Collection[str], base: str) -> str:
    """
    Picks a name that does not collide with any in `taken`.

    Args:
        taken: A collection of taken names.
        base: Base name to alter.

    Returns:
        `base` itself if it is free; otherwise the first of `base_2`, `base_3`, ...
        that is not in `taken`.
    """
    if base not in taken:
        return base

    candidates = (f"{base}_{i}" for i in count(2))
    return next(name for name in candidates if name not in taken)

Searches for a new name.

Arguments:
  • taken: A collection of taken names.
  • base: Base name to alter.
Returns:

The new, available name.

def name_sequence(prefix: str) -> Callable[[], str]:
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Builds a zero-argument callable that returns `prefix` followed by 0, 1, 2, ... on successive calls.

    Each call to `name_sequence` starts its own independent counter.
    """
    counter = count()

    def next_name() -> str:
        return f"{prefix}{next(counter)}"

    return next_name

Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").

def object_to_dict(obj: Any, **kwargs) -> Dict:
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """
    Builds a dictionary from `obj`'s instance attributes (as returned by `vars(obj)`).

    Each attribute value is copied: through its own `copy()` method when it has one,
    otherwise with `copy.copy`. Keyword arguments are merged in last, uncopied, and
    override attributes of the same name.
    """
    result = {}
    for key, value in vars(obj).items():
        result[key] = value.copy() if hasattr(value, "copy") else copy(value)
    result.update(kwargs)
    return result

Returns a dictionary created from an object's attributes.

def split_num_words( value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> List[Optional[str]]:
def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Splits `value` on `sep` and pads the result with `None` up to `min_num_words` entries.

    Args:
        value: The value to be split.
        sep: The value to use to split on.
        min_num_words: The minimum number of words that are going to be in the result.
        fill_from_start: Whether the `None` padding goes at the start (True) or the end (False).

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The words from `split`, padded with `None` when there are fewer than `min_num_words`;
        never truncated.
    """
    words = value.split(sep)
    # A negative repeat count yields an empty list, so no padding is added in that case.
    padding: t.List[t.Optional[str]] = [None] * (min_num_words - len(words))
    return padding + words if fill_from_start else words + padding

Perform a split on a value and return N words as a result with None used for words that don't exist.

Arguments:
  • value: The value to be split.
  • sep: The value to use to split on.
  • min_num_words: The minimum number of words that are going to be in the result.
  • fill_from_start: Indicates whether None values should be inserted at the start (True) or at the end (False) of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:

The list of words returned by split, possibly augmented by a number of None values.

def is_iterable(value: Any) -> bool:
def is_iterable(value: t.Any) -> bool:
    """
    Tells whether `value` should be treated as an iterable, with `str` and `bytes` treated as scalars.

    Examples:
        >>> is_iterable([1,2])
        True
        >>> is_iterable("test")
        False

    Args:
        value: The value to check if it is an iterable.

    Returns:
        `True` if `value` has an `__iter__` attribute and is not `str`/`bytes`, else `False`.
    """
    if isinstance(value, (str, bytes)):
        return False
    return hasattr(value, "__iter__")

Checks if the value is an iterable, excluding the types str and bytes.

Examples:
>>> is_iterable([1,2])
True
>>> is_iterable("test")
False
Arguments:
  • value: The value to check if it is an iterable.
Returns:

A bool value indicating if it is an iterable.

def flatten(values: Iterable[Union[Iterable[Any], Any]]) -> Iterator[Any]:
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
    """
    Lazily yields the leaves of an arbitrarily nested iterable, depth-first.

    Nesting is decided by `is_iterable`, so `str` and `bytes` are yielded whole.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened.

    Yields:
        Non-iterable elements in `values`, in iteration order.
    """
    for value in values:
        if not is_iterable(value):
            yield value
            continue
        yield from flatten(value)

Flattens an iterable that can contain both iterable and non-iterable elements. Objects of type str and bytes are not regarded as iterables.

Examples:
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Arguments:
  • values: The value to be flattened.
Yields:

Non-iterable elements in values.

def dict_depth(d: Dict) -> int:
def dict_depth(d: t.Dict) -> int:
    """
    Measures how deeply dictionaries are nested, following only the first value at each level.

    Anything without a `values` attribute has depth 0; an empty mapping has depth 1.

    Example:
        >>> dict_depth(None)
        0
        >>> dict_depth({})
        1
        >>> dict_depth({"a": "b"})
        1
        >>> dict_depth({"a": {}})
        2
        >>> dict_depth({"a": {"b": {}}})
        3
    """
    try:
        first_value = next(iter(d.values()))
    except AttributeError:
        # Not mapping-like: no `values` attribute.
        return 0
    except StopIteration:
        # A mapping with no values.
        return 1
    return 1 + dict_depth(first_value)

Get the nesting depth of a dictionary.

Example:
>>> dict_depth(None)
0
>>> dict_depth({})
1
>>> dict_depth({"a": "b"})
1
>>> dict_depth({"a": {}})
2
>>> dict_depth({"a": {"b": {}}})
3
def first(it: Iterable[~T]) -> ~T:
def first(it: t.Iterable[T]) -> T:
    """
    Returns the first element produced by iterating `it` (handy for sets, which cannot be indexed).

    Raises:
        StopIteration: If `it` yields nothing.
    """
    return next(iter(it))

Returns the first element from an iterable (useful for sets).