1import time
2from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, _base
3from typing import Any, Callable
4
5from printbuddies import ProgBar
6
7
8class _QuickPool:
9 def __init__(
10 self,
11 functions: list[Callable[..., Any]],
12 args_list: list[tuple[Any, ...]] = [],
13 kwargs_list: list[dict[str, Any]] = [],
14 max_workers: int | None = None,
15 ):
16 """Quickly implement multi-threading/processing with an optional progress bar display.
17
18 #### :params:
19
20 `functions`: A list of functions to be executed.
21
22 `args_list`: A list of tuples where each tuple consists of positional arguments to be passed to each successive function in `functions` at execution time.
23
24 `kwargs_list`: A list of dictionaries where each dictionary consists of keyword arguments to be passed to each successive function in `functions` at execution time.
25
26 `max_workers`: The maximum number of concurrent threads or processes. If `None`, the max available to the system will be used.
27
28 The return values of `functions` will be returned as a list by this class' `execute` method.
29
30 The relative ordering of `functions`, `args_list`, and `kwargs_list` matters as `args_list` and `kwargs_list` will be distributed to each function squentially.
31
32 i.e.
33 >>> for function_, args, kwargs in zip(functions, args_list, kwargs_list):
34 >>> function_(*args, **kwargs)
35
36 If `args_list` and/or `kwargs_list` are shorter than the `functions` list, empty tuples and dictionaries will be added to them, respectively.
37
38 e.g
39 >>> import time
40 >>> def dummy(seconds: int, return_val: int)->int:
41 >>> time.sleep(seconds)
42 >>> return return_val
43 >>> num = 10
44 >>> pool = ThreadPool([dummy]*10, [(i,) for i in range(num)], [{"return_val": i} for i in range(num)])
45 >>> results = pool.execute()
46 >>> print(results)
47 >>> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"""
48 self.functions = functions
49 self.args_list = args_list
50 self.kwargs_list = kwargs_list
51 self.max_workers = max_workers
52
53 @property
54 def executor(self) -> _base.Executor:
55 raise NotImplementedError
56
57 @property
58 def submissions(self) -> list:
59 num_functions = len(self.functions)
60 num_args = len(self.args_list)
61 num_kwargs = len(self.kwargs_list)
62 functions = self.functions
63 args_list = self.args_list
64 kwargs_list = self.kwargs_list
65 # Pad args_list and kwargs_list if they're shorter than len(functions)
66 if num_args < num_functions:
67 args_list.extend([tuple() for _ in range(num_functions - num_args)])
68 if num_kwargs < num_functions:
69 kwargs_list.extend([dict() for _ in range(num_functions - num_kwargs)])
70 return [
71 (function_, args, kwargs)
72 for function_, args, kwargs in zip(functions, args_list, kwargs_list)
73 ]
74
75 def execute(
76 self, show_progbar: bool = True, progbar_update_period: float = 0.001
77 ) -> list[Any]:
78 """Execute the supplied functions with their arguments, if any.
79
80 Returns a list of function call results.
81
82 #### :params:
83
84 `show_progbar`: If `True`, print a progress bar to the terminal showing how many functions have finished executing.
85
86 `progbar_update_period`: How often, in seconds, to check the number of completed functions. Only relevant if `show_progbar` is `True`."""
87 with self.executor as executor:
88 workers = [
89 executor.submit(submission[0], *submission[1], **submission[2])
90 for submission in self.submissions
91 ]
92 if show_progbar:
93 num_workers = len(workers)
94 with ProgBar(num_workers) as bar:
95 while (
96 num_complete := len(
97 [worker for worker in workers if worker.done()]
98 )
99 ) < num_workers:
100
101 bar.display(f"{bar.runtime}", counter_override=num_complete)
102 time.sleep(progbar_update_period)
103 bar.display(f"{bar.runtime}", counter_override=num_complete)
104 return [worker.result() for worker in workers]
105
106
class ProcessPool(_QuickPool):
    """Process-based pool: each submitted function runs on a `ProcessPoolExecutor`."""

    @property
    def executor(self) -> ProcessPoolExecutor:
        """Build a fresh `ProcessPoolExecutor` limited to `self.max_workers` workers."""
        worker_cap = self.max_workers
        return ProcessPoolExecutor(max_workers=worker_cap)
111
112
class ThreadPool(_QuickPool):
    """Thread-based pool: each submitted function runs on a `ThreadPoolExecutor`."""

    @property
    def executor(self) -> ThreadPoolExecutor:
        """Build a fresh `ThreadPoolExecutor` limited to `self.max_workers` workers."""
        worker_cap = self.max_workers
        return ThreadPoolExecutor(max_workers=worker_cap)