Coverage for muutils\spinner.py: 87%

116 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-15 20:56 -0600

1"""decorator `spinner_decorator` and context manager `SpinnerContext` to display a spinner 

2 

3using the base `Spinner` class while some code is running. 

4""" 

5 

6import os 

7import time 

8import threading 

9import sys 

10from functools import wraps 

11from typing import ( 

12 Callable, 

13 Any, 

14 Optional, 

15 TextIO, 

16 TypeVar, 

17 Sequence, 

18 Dict, 

19 Union, 

20 ContextManager, 

21) 

22 

23DecoratedFunction = TypeVar("DecoratedFunction", bound=Callable[..., Any]) 

24"Define a generic type for the decorated function" 

25 

26 

27SPINNER_CHARS: Dict[str, Sequence[str]] = dict( 

28 default=["|", "/", "-", "\\"], 

29 dots=[". ", ".. ", "..."], 

30 bars=["| ", "|| ", "|||"], 

31 arrows=["<", "^", ">", "v"], 

32 arrows_2=["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"], 

33 bouncing_bar=["[ ]", "[= ]", "[== ]", "[=== ]", "[ ===]", "[ ==]", "[ =]"], 

34 bouncing_ball=[ 

35 "( ● )", 

36 "( ● )", 

37 "( ● )", 

38 "( ● )", 

39 "( ●)", 

40 "( ● )", 

41 "( ● )", 

42 "( ● )", 

43 "( ● )", 

44 "(● )", 

45 ], 

46 ooo=[".", "o", "O", "o"], 

47 braille=["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"], 

48 clock=["🕛", "🕐", "🕑", "🕒", "🕓", "🕔", "🕕", "🕖", "🕗", "🕘", "🕙", "🕚"], 

49 hourglass=["⏳", "⌛"], 

50 square_corners=["◰", "◳", "◲", "◱"], 

51 triangle=["◢", "◣", "◤", "◥"], 

52 square_dot=[ 

53 "⣷", 

54 "⣯", 

55 "⣟", 

56 "⡿", 

57 "⢿", 

58 "⣻", 

59 "⣽", 

60 "⣾", 

61 ], 

62 box_bounce=["▌", "▀", "▐", "▄"], 

63 hamburger=["☱", "☲", "☴"], 

64 earth=["🌍", "🌎", "🌏"], 

65 growing_dots=["⣀", "⣄", "⣤", "⣦", "⣶", "⣷", "⣿"], 

66 dice=["⚀", "⚁", "⚂", "⚃", "⚄", "⚅"], 

67 wifi=["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"], 

68 bounce=["⠁", "⠂", "⠄", "⠂"], 

69 arc=["◜", "◠", "◝", "◞", "◡", "◟"], 

70 toggle=["⊶", "⊷"], 

71 toggle2=["▫", "▪"], 

72 toggle3=["□", "■"], 

73 toggle4=["■", "□", "▪", "▫"], 

74 toggle5=["▮", "▯"], 

75 toggle7=["⦾", "⦿"], 

76 toggle8=["◍", "◌"], 

77 toggle9=["◉", "◎"], 

78 arrow2=["⬆️ ", "↗️ ", "➡️ ", "↘️ ", "⬇️ ", "↙️ ", "⬅️ ", "↖️ "], 

79 point=["∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙"], 

80 layer=["-", "=", "≡"], 

81 speaker=["🔈 ", "🔉 ", "🔊 ", "🔉 "], 

82 orangePulse=["🔸 ", "🔶 ", "🟠 ", "🟠 ", "🔷 "], 

83 bluePulse=["🔹 ", "🔷 ", "🔵 ", "🔵 ", "🔷 "], 

84 satellite_signal=["📡 ", "📡· ", "📡·· ", "📡···", "📡 ··", "📡 ·"], 

85 rocket_orbit=["🌍🚀 ", "🌏 🚀 ", "🌎 🚀"], 

86 ogham=["ᚁ ", "ᚂ ", "ᚃ ", "ᚄ", "ᚅ"], 

87 eth=["᛫", "፡", "፥", "፤", "፧", "።", "፨"], 

88) 

89"""dict of spinner sequences to show. some from Claude 3.5 Sonnet, 

90some from [cli-spinners](https://github.com/sindresorhus/cli-spinners) 

91""" 

92 

93SPINNER_COMPLETE: Dict[str, str] = dict( 

94 default="#", 

95 dots="***", 

96 bars="|||", 

97 bouncing_bar="[====]", 

98 bouncing_ball="(●●●●●●)", 

99 braille="⣿", 

100 clock="✔️", 

101 hourglass="✔️", 

102 square_corners="◼", 

103 triangle="◆", 

104 square_dot="⣿", 

105 box_bounce="■", 

106 hamburger="☰", 

107 earth="✔️", 

108 growing_dots="⣿", 

109 dice="🎲", 

110 wifi="✔️", 

111 arc="○", 

112 toggle="-", 

113 toggle2="▪", 

114 toggle3="■", 

115 toggle4="■", 

116 toggle5="▮", 

117 toggle6="၀", 

118 toggle7="⦿", 

119 toggle8="◍", 

120 toggle9="◉", 

121 arrow2="➡️", 

122 point="●●●", 

123 layer="≡", 

124 speaker="🔊", 

125 orangePulse="🟠", 

126 bluePulse="🔵", 

127 satellite_signal="📡 ✔️ ", 

128 rocket_orbit="🌍 ✨", 

129 ogham="᚛᚜", 

130 eth="፠", 

131) 

132"string to display when the spinner is complete" 

133 

134 

135class Spinner: 

136 """displays a spinner, and optionally elapsed time and a mutable value while a function is running. 

137 

138 # Parameters: 

139 - `spinner_chars : Union[str, Sequence[str]]` 

140 sequence of strings, or key to look up in `SPINNER_CHARS`, to use as the spinner characters 

141 (defaults to `"default"`) 

142 - `update_interval : float` 

143 how often to update the spinner display in seconds 

144 (defaults to `0.1`) 

145 - `spinner_complete : str` 

146 string to display when the spinner is complete 

147 (defaults to looking up `spinner_chars` in `SPINNER_COMPLETE` or `"#"`) 

148 - `initial_value : str` 

149 initial value to display with the spinner 

150 (defaults to `""`) 

151 - `message : str` 

152 message to display with the spinner 

153 (defaults to `""`) 

154 - `format_string : str` 

155 string to format the spinner with. must have `"\\r"` prepended to clear the line. 

156 allowed keys are `spinner`, `elapsed_time`, `message`, and `value` 

157 (defaults to `"\\r{spinner} ({elapsed_time:.2f}s) {message}{value}"`) 

158 - `output_stream : TextIO` 

159 stream to write the spinner to 

160 (defaults to `sys.stdout`) 

161 - `format_string_when_updated : Union[bool,str]` 

162 whether to use a different format string when the value is updated. 

163 if `True`, use the default format string with a newline appended. if a string, use that string. 

164 this is useful if you want update_value to print to console and be preserved. 

165 (defaults to `False`) 

166 

167 # Methods: 

168 - `update_value(value: Any) -> None` 

169 update the current value displayed by the spinner 

170 

171 # Usage: 

172 

173 ## As a context manager: 

174 ```python 

175 with SpinnerContext() as sp: 

176 for i in range(1): 

177 time.sleep(0.1) 

178 spinner.update_value(f"Step {i+1}") 

179 ``` 

180 

181 ## As a decorator: 

182 ```python 

183 @spinner_decorator 

184 def long_running_function(): 

185 for i in range(1): 

186 time.sleep(0.1) 

187 spinner.update_value(f"Step {i+1}") 

188 return "Function completed" 

189 ``` 

190 """ 

191 

192 def __init__( 

193 self, 

194 *args, 

195 spinner_chars: Union[str, Sequence[str]] = "default", 

196 update_interval: float = 0.1, 

197 spinner_complete: Optional[str] = None, 

198 initial_value: str = "", 

199 message: str = "", 

200 format_string: str = "\r{spinner} ({elapsed_time:.2f}s) {message}{value}", 

201 output_stream: TextIO = sys.stdout, 

202 format_string_when_updated: Union[str, bool] = False, 

203 **kwargs: Any, 

204 ): 

205 if args: 

206 raise ValueError(f"Spinner does not accept positional arguments: {args}") 

207 if kwargs: 

208 raise ValueError( 

209 f"Spinner did not recognize these keyword arguments: {kwargs}" 

210 ) 

211 

212 # spinner display 

213 self.spinner_complete: str = ( 

214 ( 

215 # if None, use `spinner_chars` key as default 

216 SPINNER_COMPLETE.get(spinner_chars, "#") 

217 if isinstance(spinner_chars, str) 

218 else "#" 

219 ) 

220 if spinner_complete is None 

221 # if not None, use the value provided 

222 else spinner_complete 

223 ) 

224 "string to display when the spinner is complete" 

225 

226 self.spinner_chars: Sequence[str] = ( 

227 SPINNER_CHARS[spinner_chars] 

228 if isinstance(spinner_chars, str) 

229 else spinner_chars 

230 ) 

231 "sequence of strings to use as the spinner characters" 

232 

233 # special format string for when the value is updated 

234 self.format_string_when_updated: Optional[str] = None 

235 "format string to use when the value is updated" 

236 if format_string_when_updated is not False: 

237 if format_string_when_updated is True: 

238 # modify the default format string 

239 self.format_string_when_updated = format_string + "\n" 

240 elif isinstance(format_string_when_updated, str): 

241 # use the provided format string 

242 self.format_string_when_updated = format_string_when_updated 

243 else: 

244 raise TypeError( 

245 "format_string_when_updated must be a string or True, got" 

246 + f" {type(format_string_when_updated) = }{format_string_when_updated}" 

247 ) 

248 

249 # copy other kwargs 

250 self.update_interval: float = update_interval 

251 self.message: str = message 

252 self.current_value: Any = initial_value 

253 self.format_string: str = format_string 

254 self.output_stream: TextIO = output_stream 

255 

256 # test out format string 

257 try: 

258 self.format_string.format( 

259 spinner=self.spinner_chars[0], 

260 elapsed_time=0.0, 

261 message=self.message, 

262 value=self.current_value, 

263 ) 

264 except Exception as e: 

265 raise ValueError( 

266 f"Invalid format string: {format_string}. Must take keys " 

267 + "'spinner: str', 'elapsed_time: float', 'message: str', and 'value: Any'." 

268 ) from e 

269 

270 # init 

271 self.start_time: float = 0 

272 "for measuring elapsed time" 

273 self.stop_spinner: threading.Event = threading.Event() 

274 "to stop the spinner" 

275 self.spinner_thread: Optional[threading.Thread] = None 

276 "the thread running the spinner" 

277 self.value_changed: bool = False 

278 "whether the value has been updated since the last display" 

279 self.term_width: int 

280 "width of the terminal, for padding with spaces" 

281 try: 

282 self.term_width = os.get_terminal_size().columns 

283 except OSError: 

284 self.term_width = 80 

285 

286 def spin(self) -> None: 

287 "Function to run in a separate thread, displaying the spinner and optional information" 

288 i: int = 0 

289 while not self.stop_spinner.is_set(): 

290 # get current spinner str 

291 spinner: str = self.spinner_chars[i % len(self.spinner_chars)] 

292 

293 # args for display string 

294 display_parts: dict[str, Any] = dict( 

295 spinner=spinner, # str 

296 elapsed_time=time.time() - self.start_time, # float 

297 message=self.message, # str 

298 value=self.current_value, # Any, but will be formatted as str 

299 ) 

300 

301 # use the special one if needed 

302 format_str: str = self.format_string 

303 if self.value_changed and (self.format_string_when_updated is not None): 

304 self.value_changed = False 

305 format_str = self.format_string_when_updated 

306 

307 # write and flush the display string 

308 output: str = format_str.format(**display_parts).ljust(self.term_width) 

309 self.output_stream.write(output) 

310 self.output_stream.flush() 

311 

312 # wait for the next update 

313 time.sleep(self.update_interval) 

314 i += 1 

315 

316 def update_value(self, value: Any) -> None: 

317 "Update the current value displayed by the spinner" 

318 self.current_value = value 

319 self.value_changed = True 

320 

321 def start(self) -> None: 

322 "Start the spinner" 

323 self.start_time = time.time() 

324 self.spinner_thread = threading.Thread(target=self.spin) 

325 self.spinner_thread.start() 

326 

327 def stop(self) -> None: 

328 "Stop the spinner" 

329 self.output_stream.write( 

330 self.format_string.format( 

331 spinner=self.spinner_complete, 

332 elapsed_time=time.time() - self.start_time, # float 

333 message=self.message, # str 

334 value=self.current_value, # Any, but will be formatted as str 

335 ).ljust(self.term_width) 

336 ) 

337 self.stop_spinner.set() 

338 if self.spinner_thread: 

339 self.spinner_thread.join() 

340 self.output_stream.write("\n") 

341 self.output_stream.flush() 

342 

343 

344class NoOpContextManager(ContextManager): 

345 """A context manager that does nothing.""" 

346 

347 def __init__(self, *args, **kwargs): 

348 pass 

349 

350 def __enter__(self): 

351 return self 

352 

353 def __exit__(self, exc_type, exc_value, traceback): 

354 pass 

355 

356 

357class SpinnerContext(Spinner, ContextManager): 

358 "see `Spinner` for parameters" 

359 

360 def __enter__(self) -> "SpinnerContext": 

361 self.start() 

362 return self 

363 

364 def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: 

365 self.stop() 

366 

367 

368SpinnerContext.__doc__ = Spinner.__doc__ 

369 

370 

371def spinner_decorator( 

372 *args, 

373 # passed to `Spinner.__init__` 

374 spinner_chars: Union[str, Sequence[str]] = "default", 

375 update_interval: float = 0.1, 

376 spinner_complete: Optional[str] = None, 

377 initial_value: str = "", 

378 message: str = "", 

379 format_string: str = "{spinner} ({elapsed_time:.2f}s) {message}{value}", 

380 output_stream: TextIO = sys.stdout, 

381 # new kwarg 

382 mutable_kwarg_key: Optional[str] = None, 

383 **kwargs, 

384) -> Callable[[DecoratedFunction], DecoratedFunction]: 

385 """see `Spinner` for parameters. Also takes `mutable_kwarg_key` 

386 

387 `mutable_kwarg_key` is the key with which `Spinner().update_value` 

388 will be passed to the decorated function. if `None`, won't pass it. 

389 

390 """ 

391 

392 if len(args) > 1: 

393 raise ValueError( 

394 f"spinner_decorator does not accept positional arguments: {args}" 

395 ) 

396 if kwargs: 

397 raise ValueError( 

398 f"spinner_decorator did not recognize these keyword arguments: {kwargs}" 

399 ) 

400 

401 def decorator(func: DecoratedFunction) -> DecoratedFunction: 

402 @wraps(func) 

403 def wrapper(*args: Any, **kwargs: Any) -> Any: 

404 spinner: Spinner = Spinner( 

405 spinner_chars=spinner_chars, 

406 update_interval=update_interval, 

407 spinner_complete=spinner_complete, 

408 initial_value=initial_value, 

409 message=message, 

410 format_string=format_string, 

411 output_stream=output_stream, 

412 ) 

413 

414 if mutable_kwarg_key: 

415 kwargs[mutable_kwarg_key] = spinner.update_value 

416 

417 spinner.start() 

418 try: 

419 result: Any = func(*args, **kwargs) 

420 finally: 

421 spinner.stop() 

422 

423 return result 

424 

425 # TODO: fix this type ignore 

426 return wrapper # type: ignore[return-value] 

427 

428 if not args: 

429 # called as `@spinner_decorator(stuff)` 

430 return decorator 

431 else: 

432 # called as `@spinner_decorator` without parens 

433 return decorator(args[0]) 

434 

435 

436spinner_decorator.__doc__ = Spinner.__doc__