Coverage for muutils\spinner.py: 87%
116 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-15 21:53 -0600
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-15 21:53 -0600
1"""decorator `spinner_decorator` and context manager `SpinnerContext` to display a spinner
3using the base `Spinner` class while some code is running.
4"""
6import os
7import time
8import threading
9import sys
10from functools import wraps
11from typing import (
12 Callable,
13 Any,
14 Optional,
15 TextIO,
16 TypeVar,
17 Sequence,
18 Dict,
19 Union,
20 ContextManager,
21)
# type variable bound to "any callable"; used in the signature of
# `spinner_decorator` so that its return type mirrors the function it wraps
DecoratedFunction = TypeVar(
    "DecoratedFunction",
    bound=Callable[..., Any],
)
"Define a generic type for the decorated function"
# Frames of a multi-character spinner are padded to one common width: the
# spinner line is redrawn in place (via "\r"), so frames of differing width
# make the rest of the line jump around.
SPINNER_CHARS: Dict[str, Sequence[str]] = dict(
    default=["|", "/", "-", "\\"],
    dots=[".  ", ".. ", "..."],
    bars=["|  ", "|| ", "|||"],
    arrows=["<", "^", ">", "v"],
    arrows_2=["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"],
    bouncing_bar=["[    ]", "[=   ]", "[==  ]", "[=== ]", "[ ===]", "[  ==]", "[   =]"],
    # ball travels right across the 6-wide interior, then back to the left edge
    bouncing_ball=[
        "( ●    )",
        "(  ●   )",
        "(   ●  )",
        "(    ● )",
        "(     ●)",
        "(    ● )",
        "(   ●  )",
        "(  ●   )",
        "( ●    )",
        "(●     )",
    ],
    ooo=[".", "o", "O", "o"],
    braille=["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
    clock=["🕛", "🕐", "🕑", "🕒", "🕓", "🕔", "🕕", "🕖", "🕗", "🕘", "🕙", "🕚"],
    hourglass=["⏳", "⌛"],
    square_corners=["◰", "◳", "◲", "◱"],
    triangle=["◢", "◣", "◤", "◥"],
    square_dot=[
        "⣷",
        "⣯",
        "⣟",
        "⡿",
        "⢿",
        "⣻",
        "⣽",
        "⣾",
    ],
    box_bounce=["▌", "▀", "▐", "▄"],
    hamburger=["☱", "☲", "☴"],
    earth=["🌍", "🌎", "🌏"],
    growing_dots=["⣀", "⣄", "⣤", "⣦", "⣶", "⣷", "⣿"],
    dice=["⚀", "⚁", "⚂", "⚃", "⚄", "⚅"],
    wifi=["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"],
    bounce=["⠁", "⠂", "⠄", "⠂"],
    arc=["◜", "◠", "◝", "◞", "◡", "◟"],
    toggle=["⊶", "⊷"],
    toggle2=["▫", "▪"],
    toggle3=["□", "■"],
    toggle4=["■", "□", "▪", "▫"],
    toggle5=["▮", "▯"],
    # previously missing here although `SPINNER_COMPLETE` has a `toggle6` entry
    toggle6=["ဝ", "၀"],
    toggle7=["⦾", "⦿"],
    toggle8=["◍", "◌"],
    toggle9=["◉", "◎"],
    arrow2=["⬆️ ", "↗️ ", "➡️ ", "↘️ ", "⬇️ ", "↙️ ", "⬅️ ", "↖️ "],
    point=["∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙"],
    layer=["-", "=", "≡"],
    speaker=["🔈 ", "🔉 ", "🔊 ", "🔉 "],
    # last frame mirrors the second (orange diamond), not the blue one
    orangePulse=["🔸 ", "🔶 ", "🟠 ", "🟠 ", "🔶 "],
    bluePulse=["🔹 ", "🔷 ", "🔵 ", "🔵 ", "🔷 "],
    satellite_signal=["📡   ", "📡·  ", "📡·· ", "📡···", "📡 ··", "📡  ·"],
    rocket_orbit=["🌍🚀  ", "🌏 🚀 ", "🌎  🚀"],
    ogham=["ᚁ ", "ᚂ ", "ᚃ ", "ᚄ ", "ᚅ "],
    eth=["᛫", "፡", "፥", "፤", "፧", "።", "፨"],
)
"""dict of spinner sequences to show. some from Claude 3.5 Sonnet,
some from [cli-spinners](https://github.com/sindresorhus/cli-spinners)
"""
# Final symbol drawn by `Spinner.stop`, keyed by the same names as
# `SPINNER_CHARS`. `Spinner.__init__` falls back to "#" for names not listed.
SPINNER_COMPLETE: Dict[str, str] = {
    "default": "#",
    "dots": "***",
    "bars": "|||",
    "bouncing_bar": "[====]",
    "bouncing_ball": "(●●●●●●)",
    "braille": "⣿",
    "clock": "✔️",
    "hourglass": "✔️",
    "square_corners": "◼",
    "triangle": "◆",
    "square_dot": "⣿",
    "box_bounce": "■",
    "hamburger": "☰",
    "earth": "✔️",
    "growing_dots": "⣿",
    "dice": "🎲",
    "wifi": "✔️",
    "arc": "○",
    "toggle": "-",
    "toggle2": "▪",
    "toggle3": "■",
    "toggle4": "■",
    "toggle5": "▮",
    "toggle6": "၀",
    "toggle7": "⦿",
    "toggle8": "◍",
    "toggle9": "◉",
    "arrow2": "➡️",
    "point": "●●●",
    "layer": "≡",
    "speaker": "🔊",
    "orangePulse": "🟠",
    "bluePulse": "🔵",
    "satellite_signal": "📡 ✔️ ",
    "rocket_orbit": "🌍 ✨",
    "ogham": "᚛᚜",
    "eth": "፠",
}
"string to display when the spinner is complete"
class Spinner:
    """displays a spinner, and optionally elapsed time and a mutable value while a function is running.

    # Parameters:
    - `spinner_chars : Union[str, Sequence[str]]`
        sequence of strings, or key to look up in `SPINNER_CHARS`, to use as the spinner characters
        (defaults to `"default"`)
    - `update_interval : float`
        how often to update the spinner display in seconds
        (defaults to `0.1`)
    - `spinner_complete : str`
        string to display when the spinner is complete
        (defaults to looking up `spinner_chars` in `SPINNER_COMPLETE` or `"#"`)
    - `initial_value : str`
        initial value to display with the spinner
        (defaults to `""`)
    - `message : str`
        message to display with the spinner
        (defaults to `""`)
    - `format_string : str`
        string to format the spinner with. must have `"\\r"` prepended to clear the line.
        allowed keys are `spinner`, `elapsed_time`, `message`, and `value`
        (defaults to `"\\r{spinner} ({elapsed_time:.2f}s) {message}{value}"`)
    - `output_stream : TextIO`
        stream to write the spinner to
        (defaults to `sys.stdout`)
    - `format_string_when_updated : Union[bool,str]`
        whether to use a different format string for the first frame drawn after the value is updated.
        if `True`, use `format_string` with a newline appended. if a string, use that string.
        this is useful if you want `update_value` to print to console and be preserved.
        (defaults to `False`)

    # Raises:
    - `ValueError` : on positional arguments, unrecognized keyword arguments, an empty
      `spinner_chars`, or a format string that cannot be formatted with the keys above
    - `KeyError` : if `spinner_chars` is a string that is not a key of `SPINNER_CHARS`
    - `TypeError` : if `format_string_when_updated` is neither a bool nor a string

    # Methods:
    - `update_value(value: Any) -> None`
        update the current value displayed by the spinner
    - `start() -> None` / `stop() -> None`
        start / stop the background thread that draws the spinner

    # Usage:

    ## As a context manager:
    ```python
    with SpinnerContext() as sp:
        for i in range(1):
            time.sleep(0.1)
            sp.update_value(f"Step {i+1}")
    ```

    ## As a decorator:
    ```python
    @spinner_decorator(mutable_kwarg_key="update_status")
    def long_running_function(update_status):
        for i in range(1):
            time.sleep(0.1)
            update_status(f"Step {i+1}")
        return "Function completed"
    ```
    """

    def __init__(
        self,
        *args,
        spinner_chars: Union[str, Sequence[str]] = "default",
        update_interval: float = 0.1,
        spinner_complete: Optional[str] = None,
        initial_value: str = "",
        message: str = "",
        format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
        output_stream: TextIO = sys.stdout,
        format_string_when_updated: Union[str, bool] = False,
        **kwargs: Any,
    ):
        # every option is keyword-only; leftovers are rejected so typos are not ignored
        if args:
            raise ValueError(f"Spinner does not accept positional arguments: {args}")
        if kwargs:
            raise ValueError(
                f"Spinner did not recognize these keyword arguments: {kwargs}"
            )

        # spinner display
        self.spinner_complete: str = (
            (
                # if None, use `spinner_chars` key as default
                SPINNER_COMPLETE.get(spinner_chars, "#")
                if isinstance(spinner_chars, str)
                else "#"
            )
            if spinner_complete is None
            # if not None, use the value provided
            else spinner_complete
        )
        "string to display when the spinner is complete"

        # a string is looked up in `SPINNER_CHARS` (KeyError if unknown)
        self.spinner_chars: Sequence[str] = (
            SPINNER_CHARS[spinner_chars]
            if isinstance(spinner_chars, str)
            else spinner_chars
        )
        "sequence of strings to use as the spinner characters"
        # an empty sequence used to surface as a misleading "Invalid format string"
        # error below, and would make `spin` divide by zero
        if len(self.spinner_chars) == 0:
            raise ValueError("spinner_chars must contain at least one string")

        # special format string for when the value is updated
        self.format_string_when_updated: Optional[str] = None
        "format string to use when the value is updated"
        if format_string_when_updated is not False:
            if format_string_when_updated is True:
                # the given `format_string` plus a newline, so the updated line is kept
                self.format_string_when_updated = format_string + "\n"
            elif isinstance(format_string_when_updated, str):
                # use the provided format string
                self.format_string_when_updated = format_string_when_updated
            else:
                raise TypeError(
                    "format_string_when_updated must be a string or True, got"
                    + f" {type(format_string_when_updated) = }{format_string_when_updated}"
                )

        # copy other kwargs
        self.update_interval: float = update_interval
        self.message: str = message
        self.current_value: Any = initial_value
        self.format_string: str = format_string
        self.output_stream: TextIO = output_stream

        # test out every format string now, so a bad one fails here instead of
        # inside the background drawing thread
        for fmt in (self.format_string, self.format_string_when_updated):
            if fmt is None:
                continue
            try:
                fmt.format(
                    spinner=self.spinner_chars[0],
                    elapsed_time=0.0,
                    message=self.message,
                    value=self.current_value,
                )
            except Exception as e:
                raise ValueError(
                    f"Invalid format string: {fmt}. Must take keys "
                    + "'spinner: str', 'elapsed_time: float', 'message: str', and 'value: Any'."
                ) from e

        # init
        self.start_time: float = 0
        "for measuring elapsed time"
        self.stop_spinner: threading.Event = threading.Event()
        "to stop the spinner"
        self.spinner_thread: Optional[threading.Thread] = None
        "the thread running the spinner"
        self.value_changed: bool = False
        "whether the value has been updated since the last display"
        self.term_width: int
        "width of the terminal, for padding with spaces"
        try:
            self.term_width = os.get_terminal_size().columns
        except OSError:
            self.term_width = 80

    def spin(self) -> None:
        "Function to run in a separate thread, displaying the spinner and optional information"
        i: int = 0
        while not self.stop_spinner.is_set():
            # get current spinner str
            spinner: str = self.spinner_chars[i % len(self.spinner_chars)]

            # args for display string
            display_parts: Dict[str, Any] = dict(
                spinner=spinner,  # str
                elapsed_time=time.time() - self.start_time,  # float
                message=self.message,  # str
                value=self.current_value,  # Any, but will be formatted as str
            )

            # use the special one for the first frame after `update_value`
            format_str: str = self.format_string
            if self.value_changed and (self.format_string_when_updated is not None):
                self.value_changed = False
                format_str = self.format_string_when_updated

            # write and flush the display string, padded to overwrite the previous frame
            output: str = format_str.format(**display_parts).ljust(self.term_width)
            self.output_stream.write(output)
            self.output_stream.flush()

            # wait for the next update; unlike `time.sleep`, this returns as soon
            # as `stop()` sets the event, so stopping does not block a full interval
            self.stop_spinner.wait(self.update_interval)
            i += 1

    def update_value(self, value: Any) -> None:
        "Update the current value displayed by the spinner"
        self.current_value = value
        self.value_changed = True

    def start(self) -> None:
        "Start the spinner"
        # clear a previous `stop()` so the same instance can be started again
        self.stop_spinner.clear()
        self.start_time = time.time()
        self.spinner_thread = threading.Thread(target=self.spin)
        self.spinner_thread.start()

    def stop(self) -> None:
        "Stop the spinner, then print the completion line followed by a newline"
        # stop and join the drawing thread *before* writing the final line, so a
        # frame drawn concurrently cannot overwrite the completion message
        self.stop_spinner.set()
        if self.spinner_thread is not None:
            self.spinner_thread.join()
        self.output_stream.write(
            self.format_string.format(
                spinner=self.spinner_complete,
                elapsed_time=time.time() - self.start_time,  # float
                message=self.message,  # str
                value=self.current_value,  # Any, but will be formatted as str
            ).ljust(self.term_width)
        )
        self.output_stream.write("\n")
        self.output_stream.flush()
class NoOpContextManager(ContextManager):
    """Context manager whose enter and exit do nothing.

    Any constructor arguments are accepted and ignored. Entering yields the
    manager itself; exiting performs no cleanup and does not suppress
    exceptions raised in the `with` body.
    """

    def __init__(self, *_args, **_kwargs):
        # arguments are intentionally discarded
        return None

    def __enter__(self):
        manager = self
        return manager

    def __exit__(self, exc_type, exc_value, traceback):
        # returning None (falsy) lets any exception propagate
        return None
class SpinnerContext(Spinner, ContextManager):
    # share the full `Spinner` documentation (parameters, usage) with this class
    __doc__ = Spinner.__doc__

    def __enter__(self) -> "SpinnerContext":
        # launch the drawing thread, then hand the spinner to the `with` body
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # always stop, even if the body raised; returning None means any
        # exception propagates instead of being suppressed
        self.stop()
def spinner_decorator(
    *args,
    # passed to `Spinner.__init__`
    spinner_chars: Union[str, Sequence[str]] = "default",
    update_interval: float = 0.1,
    spinner_complete: Optional[str] = None,
    initial_value: str = "",
    message: str = "",
    format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
    output_stream: TextIO = sys.stdout,
    format_string_when_updated: Union[str, bool] = False,
    # new kwarg
    mutable_kwarg_key: Optional[str] = None,
    **kwargs,
) -> Callable[[DecoratedFunction], DecoratedFunction]:
    """decorator that shows a `Spinner` while the decorated function runs.

    can be used bare (`@spinner_decorator`) or with keyword arguments
    (`@spinner_decorator(message="working ")`). every keyword argument except
    `mutable_kwarg_key` is forwarded to `Spinner` -- see `Spinner` for parameters.
    note that `format_string` must start with `"\\r"` so each frame overwrites the last.

    `mutable_kwarg_key` is the key with which `Spinner().update_value`
    will be passed to the decorated function. if `None`, won't pass it.

    # Raises:
    - `ValueError` : if more than one positional argument is given, if an
      unrecognized keyword argument is given, or if the single positional
      argument is not callable (options must be passed as keywords)
    """

    if len(args) > 1:
        raise ValueError(
            f"spinner_decorator does not accept positional arguments: {args}"
        )
    if kwargs:
        raise ValueError(
            f"spinner_decorator did not recognize these keyword arguments: {kwargs}"
        )
    # the only positional argument allowed is the function in bare `@spinner_decorator`
    # usage; e.g. `@spinner_decorator("dots")` would otherwise fail obscurely later
    if args and not callable(args[0]):
        raise ValueError(
            f"spinner_decorator does not accept positional arguments: {args}"
        )

    def decorator(func: DecoratedFunction) -> DecoratedFunction:
        @wraps(func)
        def wrapper(*func_args: Any, **func_kwargs: Any) -> Any:
            # a fresh spinner per call, so concurrent/repeated calls don't share state
            spinner: Spinner = Spinner(
                spinner_chars=spinner_chars,
                update_interval=update_interval,
                spinner_complete=spinner_complete,
                initial_value=initial_value,
                message=message,
                format_string=format_string,
                output_stream=output_stream,
                format_string_when_updated=format_string_when_updated,
            )

            if mutable_kwarg_key:
                func_kwargs[mutable_kwarg_key] = spinner.update_value

            spinner.start()
            try:
                result: Any = func(*func_args, **func_kwargs)
            finally:
                # stop the spinner thread even if `func` raises
                spinner.stop()

            return result

        # TODO: fix this type ignore
        return wrapper  # type: ignore[return-value]

    if not args:
        # called as `@spinner_decorator(stuff)`
        return decorator
    else:
        # called as `@spinner_decorator` without parens
        return decorator(args[0])
# keep the decorator's own docstring (it documents `mutable_kwarg_key`) and append
# the shared `Spinner` documentation, rather than overwriting it; stays None if
# docstrings are stripped (python -OO)
spinner_decorator.__doc__ = (
    "\n\n".join(doc for doc in (spinner_decorator.__doc__, Spinner.__doc__) if doc)
    or None
)