Coverage for muutils\spinner.py: 87%

116 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-09 01:48 -0600

1"""decorator `spinner_decorator` and context manager `SpinnerContext` to display a spinner 

2 

3using the base `Spinner` class while some code is running. 

4""" 

5 

6import os 

7import time 

8import threading 

9import sys 

10from functools import wraps 

11from typing import Callable, Any, Optional, TextIO, TypeVar, Sequence, Dict, Union 

12 

# type variable bound to "any callable", used so decorators can declare that
# they hand back something of the same type they were given
DecoratedFunction = TypeVar(
    "DecoratedFunction",
    bound=Callable[..., Any],
)
"generic type variable for the function being decorated"

15 

16 

17SPINNER_CHARS: Dict[str, Sequence[str]] = dict( 

18 default=["|", "/", "-", "\\"], 

19 dots=[". ", ".. ", "..."], 

20 bars=["| ", "|| ", "|||"], 

21 arrows=["<", "^", ">", "v"], 

22 arrows_2=["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"], 

23 bouncing_bar=["[ ]", "[= ]", "[== ]", "[=== ]", "[ ===]", "[ ==]", "[ =]"], 

24 bouncing_ball=[ 

25 "( ● )", 

26 "( ● )", 

27 "( ● )", 

28 "( ● )", 

29 "( ●)", 

30 "( ● )", 

31 "( ● )", 

32 "( ● )", 

33 "( ● )", 

34 "(● )", 

35 ], 

36 ooo=[".", "o", "O", "o"], 

37 braille=["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"], 

38 clock=["🕛", "🕐", "🕑", "🕒", "🕓", "🕔", "🕕", "🕖", "🕗", "🕘", "🕙", "🕚"], 

39 hourglass=["⏳", "⌛"], 

40 square_corners=["◰", "◳", "◲", "◱"], 

41 triangle=["◢", "◣", "◤", "◥"], 

42 square_dot=[ 

43 "⣷", 

44 "⣯", 

45 "⣟", 

46 "⡿", 

47 "⢿", 

48 "⣻", 

49 "⣽", 

50 "⣾", 

51 ], 

52 box_bounce=["▌", "▀", "▐", "▄"], 

53 hamburger=["☱", "☲", "☴"], 

54 earth=["🌍", "🌎", "🌏"], 

55 growing_dots=["⣀", "⣄", "⣤", "⣦", "⣶", "⣷", "⣿"], 

56 dice=["⚀", "⚁", "⚂", "⚃", "⚄", "⚅"], 

57 wifi=["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"], 

58 bounce=["⠁", "⠂", "⠄", "⠂"], 

59 arc=["◜", "◠", "◝", "◞", "◡", "◟"], 

60 toggle=["⊶", "⊷"], 

61 toggle2=["▫", "▪"], 

62 toggle3=["□", "■"], 

63 toggle4=["■", "□", "▪", "▫"], 

64 toggle5=["▮", "▯"], 

65 toggle7=["⦾", "⦿"], 

66 toggle8=["◍", "◌"], 

67 toggle9=["◉", "◎"], 

68 arrow2=["⬆️ ", "↗️ ", "➡️ ", "↘️ ", "⬇️ ", "↙️ ", "⬅️ ", "↖️ "], 

69 point=["∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙"], 

70 layer=["-", "=", "≡"], 

71 speaker=["🔈 ", "🔉 ", "🔊 ", "🔉 "], 

72 orangePulse=["🔸 ", "🔶 ", "🟠 ", "🟠 ", "🔷 "], 

73 bluePulse=["🔹 ", "🔷 ", "🔵 ", "🔵 ", "🔷 "], 

74 satellite_signal=["📡 ", "📡· ", "📡·· ", "📡···", "📡 ··", "📡 ·"], 

75 rocket_orbit=["🌍🚀 ", "🌏 🚀 ", "🌎 🚀"], 

76 ogham=["ᚁ ", "ᚂ ", "ᚃ ", "ᚄ", "ᚅ"], 

77 eth=["᛫", "፡", "፥", "፤", "፧", "።", "፨"], 

78) 

79"""dict of spinner sequences to show. some from Claude 3.5 Sonnet, 

80some from [cli-spinners](https://github.com/sindresorhus/cli-spinners) 

81""" 

82 

# completion marker per named spinner; `Spinner` falls back to "#" for names
# that are missing here
SPINNER_COMPLETE: Dict[str, str] = {
    "default": "#",
    "dots": "***",
    "bars": "|||",
    "bouncing_bar": "[====]",
    "bouncing_ball": "(●●●●●●)",
    "braille": "⣿",
    "clock": "✔️",
    "hourglass": "✔️",
    "square_corners": "◼",
    "triangle": "◆",
    "square_dot": "⣿",
    "box_bounce": "■",
    "hamburger": "☰",
    "earth": "✔️",
    "growing_dots": "⣿",
    "dice": "🎲",
    "wifi": "✔️",
    "arc": "○",
    "toggle": "-",
    "toggle2": "▪",
    "toggle3": "■",
    "toggle4": "■",
    "toggle5": "▮",
    "toggle6": "၀",
    "toggle7": "⦿",
    "toggle8": "◍",
    "toggle9": "◉",
    "arrow2": "➡️",
    "point": "●●●",
    "layer": "≡",
    "speaker": "🔊",
    "orangePulse": "🟠",
    "bluePulse": "🔵",
    "satellite_signal": "📡 ✔️ ",
    "rocket_orbit": "🌍 ✨",
    "ogham": "᚛᚜",
    "eth": "፠",
}
"string shown in place of the spinner once it has finished, keyed by spinner name"

123 

124 

class Spinner:
    """displays a spinner, and optionally elapsed time and a mutable value while a function is running.

    # Parameters:
    - `spinner_chars : Union[str, Sequence[str]]`
        sequence of strings, or key to look up in `SPINNER_CHARS`, to use as the spinner characters
        (defaults to `"default"`)
    - `update_interval : float`
        how often to update the spinner display in seconds
        (defaults to `0.1`)
    - `spinner_complete : str`
        string to display when the spinner is complete
        (defaults to looking up `spinner_chars` in `SPINNER_COMPLETE` or `"#"`)
    - `initial_value : str`
        initial value to display with the spinner
        (defaults to `""`)
    - `message : str`
        message to display with the spinner
        (defaults to `""`)
    - `format_string : str`
        string to format the spinner with. must have `"\\r"` prepended to clear the line.
        allowed keys are `spinner`, `elapsed_time`, `message`, and `value`
        (defaults to `"\\r{spinner} ({elapsed_time:.2f}s) {message}{value}"`)
    - `output_stream : TextIO`
        stream to write the spinner to
        (defaults to `sys.stdout`, looked up when the spinner is created)
    - `format_string_when_updated : Union[bool,str]`
        whether to use a different format string when the value is updated.
        if `True`, use `format_string` with a newline appended. if a string, use that string.
        this is useful if you want update_value to print to console and be preserved.
        (defaults to `False`)

    # Raises:
    - `ValueError` : for positional or unrecognized keyword arguments, an empty
        `spinner_chars` sequence, or a format string that cannot be formatted
        with the allowed keys
    - `KeyError` : if `spinner_chars` is a string that is not a key of `SPINNER_CHARS`
    - `TypeError` : if `format_string_when_updated` is neither a `bool` nor a `str`

    # Methods:
    - `update_value(value: Any) -> None`
        update the current value displayed by the spinner
    - `start() -> None`, `stop() -> None`
        start / stop the background thread that draws the spinner

    # Usage:

    ## As a context manager:
    ```python
    with SpinnerContext() as sp:
        for i in range(1):
            time.sleep(0.1)
            sp.update_value(f"Step {i+1}")
    ```

    ## As a decorator:
    ```python
    @spinner_decorator(mutable_kwarg_key="update_status")
    def long_running_function(update_status):
        for i in range(1):
            time.sleep(0.1)
            update_status(f"Step {i+1}")
        return "Function completed"
    ```
    """

    def __init__(
        self,
        *args,
        spinner_chars: Union[str, Sequence[str]] = "default",
        update_interval: float = 0.1,
        spinner_complete: Optional[str] = None,
        initial_value: str = "",
        message: str = "",
        format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
        output_stream: Optional[TextIO] = None,
        format_string_when_updated: Union[str, bool] = False,
        **kwargs: Any,
    ):
        if args:
            raise ValueError(f"Spinner does not accept positional arguments: {args}")
        if kwargs:
            raise ValueError(
                f"Spinner did not recognize these keyword arguments: {kwargs}"
            )

        # spinner display
        self.spinner_complete: str = (
            (
                # if None, use `spinner_chars` key as default
                SPINNER_COMPLETE.get(spinner_chars, "#")
                if isinstance(spinner_chars, str)
                else "#"
            )
            if spinner_complete is None
            # if not None, use the value provided
            else spinner_complete
        )
        "string to display when the spinner is complete"

        resolved_chars: Sequence[str]
        if isinstance(spinner_chars, str):
            # a string names one of the predefined sequences
            if spinner_chars not in SPINNER_CHARS:
                raise KeyError(
                    f"unknown spinner_chars key {spinner_chars!r}, "
                    + f"expected one of: {sorted(SPINNER_CHARS)}"
                )
            resolved_chars = SPINNER_CHARS[spinner_chars]
        else:
            resolved_chars = spinner_chars
        if len(resolved_chars) == 0:
            # `spin` indexes frames modulo the length, so at least one is required
            raise ValueError("spinner_chars must contain at least one frame")
        self.spinner_chars: Sequence[str] = resolved_chars
        "sequence of strings to use as the spinner characters"

        # special format string for when the value is updated
        self.format_string_when_updated: Optional[str] = None
        "format string to use when the value is updated"
        if format_string_when_updated is not False:
            if format_string_when_updated is True:
                # modify the given format string
                self.format_string_when_updated = format_string + "\n"
            elif isinstance(format_string_when_updated, str):
                # use the provided format string
                self.format_string_when_updated = format_string_when_updated
            else:
                raise TypeError(
                    "format_string_when_updated must be a string or a bool, got"
                    + f" {type(format_string_when_updated).__name__}: {format_string_when_updated!r}"
                )

        # copy other kwargs
        self.update_interval: float = update_interval
        self.message: str = message
        self.current_value: Any = initial_value
        self.format_string: str = format_string
        # resolved here rather than in the signature so that a `sys.stdout`
        # replaced after import (e.g. by `contextlib.redirect_stdout`) is honored
        self.output_stream: TextIO = (
            sys.stdout if output_stream is None else output_stream
        )

        # test out the format string(s) now, so a bad one fails here rather
        # than inside the background thread
        for fmt in (self.format_string, self.format_string_when_updated):
            if fmt is None:
                continue
            try:
                fmt.format(
                    spinner=self.spinner_chars[0],
                    elapsed_time=0.0,
                    message=self.message,
                    value=self.current_value,
                )
            except Exception as e:
                raise ValueError(
                    f"Invalid format string: {fmt}. Must take keys "
                    + "'spinner: str', 'elapsed_time: float', 'message: str', and 'value: Any'."
                ) from e

        # init
        self.start_time: float = 0
        "for measuring elapsed time"
        self.stop_spinner: threading.Event = threading.Event()
        "to stop the spinner"
        self.spinner_thread: Optional[threading.Thread] = None
        "the thread running the spinner"
        self.value_changed: bool = False
        "whether the value has been updated since the last display"
        self.term_width: int
        "width of the terminal, for padding with spaces"
        try:
            self.term_width = os.get_terminal_size().columns
        except OSError:
            self.term_width = 80

    def spin(self) -> None:
        "Function to run in a separate thread, displaying the spinner and optional information"
        i: int = 0
        while not self.stop_spinner.is_set():
            # get current spinner str
            spinner: str = self.spinner_chars[i % len(self.spinner_chars)]

            # pick (and consume) the "updated" format *before* reading the value:
            # `update_value` sets the value first and the flag second, so a
            # flag seen here guarantees the new value is the one displayed
            format_str: str = self.format_string
            if self.value_changed and (self.format_string_when_updated is not None):
                self.value_changed = False
                format_str = self.format_string_when_updated

            # args for display string
            display_parts: Dict[str, Any] = dict(
                spinner=spinner,  # str
                elapsed_time=time.time() - self.start_time,  # float
                message=self.message,  # str
                value=self.current_value,  # Any, but will be formatted as str
            )

            # write and flush the display string, padded to overwrite the whole line
            output: str = format_str.format(**display_parts).ljust(self.term_width)
            self.output_stream.write(output)
            self.output_stream.flush()

            # wait for the next update; returns immediately once `stop()` sets the event
            self.stop_spinner.wait(self.update_interval)
            i += 1

    def update_value(self, value: Any) -> None:
        "Update the current value displayed by the spinner"
        self.current_value = value
        self.value_changed = True

    def start(self) -> None:
        "Start the spinner"
        self.start_time = time.time()
        # clear any earlier stop request so a stopped spinner can be restarted
        self.stop_spinner.clear()
        # daemon: a spinner that is never stopped must not keep the process alive
        self.spinner_thread = threading.Thread(target=self.spin, daemon=True)
        self.spinner_thread.start()

    def stop(self) -> None:
        "Stop the spinner and write the completion line"
        # halt the drawing thread *first*, so no late spinner frame can be
        # written after (and overwrite) the completion line below
        self.stop_spinner.set()
        if self.spinner_thread:
            self.spinner_thread.join()
        self.output_stream.write(
            self.format_string.format(
                spinner=self.spinner_complete,
                elapsed_time=time.time() - self.start_time,  # float
                message=self.message,  # str
                value=self.current_value,  # Any, but will be formatted as str
            ).ljust(self.term_width)
        )
        self.output_stream.write("\n")
        self.output_stream.flush()

332 

333 

class SpinnerContext(Spinner):
    # share the complete parameter/usage documentation of `Spinner`; assigning
    # `__doc__` in the class body has the same effect as setting it afterwards
    __doc__ = Spinner.__doc__

    def __enter__(self) -> "SpinnerContext":
        # begin drawing in the background, then expose this spinner to the `with` body
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # stop on every exit path; returning None means exceptions are not suppressed
        self.stop()

346 

347 

def spinner_decorator(
    *args,
    # passed to `Spinner.__init__`
    spinner_chars: Union[str, Sequence[str]] = "default",
    update_interval: float = 0.1,
    spinner_complete: Optional[str] = None,
    initial_value: str = "",
    message: str = "",
    format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}",
    output_stream: Optional[TextIO] = None,
    format_string_when_updated: Union[str, bool] = False,
    # new kwarg
    mutable_kwarg_key: Optional[str] = None,
    **kwargs,
) -> Callable[[DecoratedFunction], DecoratedFunction]:
    """decorate a function so that a `Spinner` is displayed while it runs.

    usable bare (`@spinner_decorator`) or with keyword arguments
    (`@spinner_decorator(message="working...")`). every keyword argument except
    `mutable_kwarg_key` is forwarded to `Spinner` -- see `Spinner` for parameters.
    `format_string` defaults to the same `"\\r"`-prefixed string as `Spinner`,
    and `output_stream` defaults to `sys.stdout` as it is at call time.

    `mutable_kwarg_key` is the key with which `Spinner().update_value`
    will be passed to the decorated function. if `None`, won't pass it.

    # Raises:
    - `ValueError` : if given more than one positional argument, a positional
        argument that is not callable, or an unrecognized keyword argument
    """

    # the only positional argument allowed is the function itself (bare usage)
    if len(args) > 1 or (args and not callable(args[0])):
        raise ValueError(
            f"spinner_decorator does not accept positional arguments: {args}"
        )
    if kwargs:
        raise ValueError(
            f"spinner_decorator did not recognize these keyword arguments: {kwargs}"
        )

    def decorator(func: DecoratedFunction) -> DecoratedFunction:
        @wraps(func)
        def wrapper(*call_args: Any, **call_kwargs: Any) -> Any:
            # a fresh spinner per call, so concurrent/repeated calls don't share state
            spinner: Spinner = Spinner(
                spinner_chars=spinner_chars,
                update_interval=update_interval,
                spinner_complete=spinner_complete,
                initial_value=initial_value,
                message=message,
                format_string=format_string,
                # looked up per call so a redirected/captured stdout is honored
                output_stream=sys.stdout if output_stream is None else output_stream,
                format_string_when_updated=format_string_when_updated,
            )

            if mutable_kwarg_key:
                call_kwargs[mutable_kwarg_key] = spinner.update_value

            spinner.start()
            try:
                return func(*call_args, **call_kwargs)
            finally:
                # always stop the spinner thread, even if `func` raises
                spinner.stop()

        # TODO: fix this type ignore
        return wrapper  # type: ignore[return-value]

    if not args:
        # called as `@spinner_decorator(stuff)`
        return decorator
    # called as `@spinner_decorator` without parens
    return decorator(args[0])

411 

412 

# keep the decorator's own notes (e.g. on `mutable_kwarg_key`) and append the
# shared `Spinner` parameter documentation, rather than discarding the former
if Spinner.__doc__ is not None:
    spinner_decorator.__doc__ = (
        (spinner_decorator.__doc__ or "") + "\n" + Spinner.__doc__
    )

414 

415 

class NoOpContextManager:
    """A context manager that does nothing.

    any constructor arguments are accepted and ignored, `__enter__` yields the
    manager itself, and `__exit__` returns `None`, so exceptions propagate.
    """

    def __init__(self, *args, **kwargs):
        # accepted only so this can stand in for managers that take arguments
        del args, kwargs

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # falsy return value: never swallow an exception from the `with` body
        return None