Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/streams/progress.py: 78%
54 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 17:10 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 17:10 +0200
1import logging
2import os
3from dataclasses import dataclass, field
4from datetime import datetime, timedelta
5from io import StringIO
7from anyio import TypedAttributeSet, typed_attribute
8from anyio.abc import ObjectStream
9from rich.console import Console
10from rich.progress import (
11 BarColumn,
12 DownloadColumn,
13 Progress,
14 TaskID,
15 TextColumn,
16 TimeElapsedColumn,
17 TimeRemainingColumn,
18 TransferSpeedColumn,
19)
# Module-level logger; ProgressStream writes rendered progress snapshots here
# when it runs in logging mode.
logger: logging.Logger = logging.getLogger(__name__)
class ProgressAttribute(TypedAttributeSet):
    """Typed attributes a wrapped stream can expose to drive progress reporting."""

    # Expected total size of the transfer. ProgressStream reads it via
    # ``stream.extra(ProgressAttribute.total, None)`` and uses it as the
    # progress task total (presumably a byte count, since progress advances
    # by ``len(item)`` — confirm with producers of this attribute).
    total: float = typed_attribute()
@dataclass(kw_only=True)
class ProgressStream(ObjectStream[bytes]):
    """Wrap a bytes ``ObjectStream`` and report transfer progress with rich.

    Each item received from, or sent to, the wrapped ``stream`` advances a
    rich ``Progress`` task by ``len(item)``. Receiving and sending each get
    their own task, created lazily on first use. The task total comes from the
    wrapped stream's ``ProgressAttribute.total`` extra attribute, or ``None``
    (unknown total) when the stream does not provide it.

    Attributes:
        stream: The wrapped object stream that actually carries the data.
        logging: When true, the interactive progress display is disabled and
            a rendered snapshot of the progress is emitted via ``logger.info``
            at most once every two seconds instead.
    """

    stream: ObjectStream
    logging: bool = False

    # Double-underscore names are mangled (e.g. to ``_ProgressStream__prog``),
    # keeping this internal state out of the way of subclasses.
    __prog: Progress | None = field(init=False, default=None)
    __recv: TaskID | None = field(init=False, default=None)
    __send: TaskID | None = field(init=False, default=None)
    # Time of the last progress snapshot written to the log.
    __last: datetime = field(init=False, default_factory=datetime.now)

    def __post_init__(self):
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

        self.__prog = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TransferSpeedColumn(),
            DownloadColumn(),
            TextColumn("Elapsed:"),
            TimeElapsedColumn(),
            TextColumn("Remaining:"),
            TimeRemainingColumn(),
            # The live display is suppressed in logging mode and on dumb
            # terminals; progress is still tracked either way.
            disable=self.logging or os.environ.get("TERM") == "dumb",
        )

    def __del__(self):
        self._stop_progress()

    def _stop_progress(self) -> None:
        """Stop the live progress display if it was created and is running."""
        # ``__prog`` stays None if construction failed before __post_init__.
        prog = self.__prog
        if prog is not None and prog.live.is_started:
            prog.stop()

    def _start_task(self) -> TaskID:
        """Start the progress display and add a new ``transfer`` task."""
        self.__prog.start()
        return self.__prog.add_task(
            "transfer",
            total=self.stream.extra(ProgressAttribute.total, None),
        )

    def _advance(self, task: TaskID, amount: int) -> None:
        """Advance ``task`` by ``amount`` and, in logging mode, log a snapshot.

        Snapshots are throttled to at most one every two seconds.
        """
        self.__prog.advance(task, amount)
        if self.logging:
            now = datetime.now()
            if now - self.__last > timedelta(seconds=2):
                self.__last = now
                # Render the progress table into an in-memory console so it
                # can be emitted as a single log record.
                console = Console(file=StringIO())
                console.print(self.__prog.get_renderable())
                logger.info(console.file.getvalue().rstrip())

    async def receive(self):
        """Receive the next item from the wrapped stream and record its size."""
        if self.__recv is None:
            self.__recv = self._start_task()

        item = await self.stream.receive()

        self._advance(self.__recv, len(item))
        return item

    async def send(self, item):
        """Record the size of ``item`` and forward it to the wrapped stream."""
        if self.__send is None:
            self.__send = self._start_task()

        # Advance the *send* task. The previous code advanced the receive
        # task here, which fails when send() is used without receive().
        self._advance(self.__send, len(item))

        await self.stream.send(item)

    async def send_eof(self):
        """Forward end-of-stream to the wrapped stream."""
        await self.stream.send_eof()

    async def aclose(self):
        """Close the wrapped stream, then stop the live progress display."""
        try:
            await self.stream.aclose()
        finally:
            self._stop_progress()