from itertools import repeat, count
from typing import Iterable, Optional, Any, Coroutine, Callable
import trio
from tqdm.auto import tqdm
from many_requests.common import N_WORKERS_DEFAULT, is_collection
[docs]def delayed(func: Callable[..., Coroutine]):
"""
Decorator used to capture an async function with arguments and delay its execution.
Examples:
>>> import asks
>>> func_, args_, kwargs_ = delayed(asks.request)('GET', url='https://example.org')
>>> assert func_ == asks.request
>>> assert args_ == ('GET',)
>>> assert kwargs_ == {'url': 'https://example.org'}
"""
def delayed_func(*args, **kwargs):
return func, args, kwargs
return delayed_func
[docs]def zip_kw(**kwargs):
"""
Return an iterator of N dictionaries, where the Nth item of each iterator specified in `kwargs` is paired together.
Like `zip` but returns a dict instead.
Also keyword arguments which are strings or not iterators are automatically repeated N times.
`zip_kw` stops when the shortest iterator has been exhausted.
"""
has_collection = False
for k, v in kwargs.items():
if not is_collection(v):
kwargs[k] = repeat(v)
else:
has_collection = True
kwargs[k] = iter(v)
counter = count() if has_collection else range(1)
try:
for i in counter:
dict_item = {k: next(v) for k, v in kwargs.items()}
yield dict_item
except StopIteration:
pass
[docs]class EasyAsync:
[docs] def __init__(self, n_workers: int = N_WORKERS_DEFAULT):
"""
Dead simple parallel execution of async coroutines.
`n_workers` are dispatched which asynchronously process each task given.
Args:
n_workers: Number of workers to use
Examples:
Each task sleeps for `i` seconds asyncronosly:
>>> EasyAsync(n_workers = 4)(delayed(trio.sleep)(i) for i in range(10))
Each task calculates `isclose` with a different `a` and `b` paramter. `abs_tol` is set the same for all.
>>> from math import isclose
>>> async def isclose_(a, b, abs_tol): return isclose(a=a, b=b, abs_tol=abs_tol)
>>> EasyAsync(n_workers = 4)(
>>> delayed(isclose_)(**kwargs)
>>> for kwargs in zip_kw(a=range(10), b=range(10)[::-1], abs_tol=4))
"""
self.n_workers = n_workers
self.progress_bar = None
[docs] def __call__(self, tasks: Iterable[delayed], length: Optional[int] = None):
"""
Execute given coroutine `tasks` using workers. The order of output will match the order of input.
Args:
tasks: A sequence of coroutines to execute
length: The number of tasks. If not specified it will try obtain the length automatically.
Used for progress bar
Returns:
A list of outputs from each task. The order of items is determined by the input.
"""
if length is None:
try:
length = len(tasks)
except TypeError:
pass
self.progress_bar = tqdm(total=length, smoothing=0)
self.tasks = enumerate(iter(tasks)) # Force the sequence to be a iterator
self.outputs = []
trio.run(self._worker_nursery)
# Order output and remove idx
self.outputs = sorted(self.outputs, key=lambda e: e[0])
self.outputs = [e[1] for e in self.outputs]
return self.outputs
async def _worker_nursery(self):
"""Start a trio nursery with n workers"""
async with trio.open_nursery() as nursery:
for i in range(self.n_workers):
nursery.start_soon(self._worker)
async def _worker(self):
"""Execute tasks until exhausted"""
for idx, task in self.tasks:
function, args, kwargs = task
output = await function(*args, **kwargs)
self.outputs.append((idx, output))
self._progress_update(1)
def _progress_update(self, n: int):
"""Increment progress bar"""
if self.progress_bar is not None:
self.progress_bar.update(n)