Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/streams/progress.py: 78%

54 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 17:10 +0200

1import logging 

2import os 

3from dataclasses import dataclass, field 

4from datetime import datetime, timedelta 

5from io import StringIO 

6 

7from anyio import TypedAttributeSet, typed_attribute 

8from anyio.abc import ObjectStream 

9from rich.console import Console 

10from rich.progress import ( 

11 BarColumn, 

12 DownloadColumn, 

13 Progress, 

14 TaskID, 

15 TextColumn, 

16 TimeElapsedColumn, 

17 TimeRemainingColumn, 

18 TransferSpeedColumn, 

19) 

20 

# Module-level logger. ProgressStream writes periodic text snapshots of its
# progress bar here when constructed with ``logging=True``.
logger = logging.getLogger(__name__)

22 

23 

class ProgressAttribute(TypedAttributeSet):
    """Typed extra attributes that ProgressStream queries on the stream it wraps."""

    # Expected total size of the transfer. ProgressStream reads it with
    # ``stream.extra(ProgressAttribute.total, None)`` to size its progress bar;
    # when absent the bar has no total. Presumably measured in bytes, since the
    # bar is advanced by ``len(item)`` of each chunk -- confirm with providers.
    total: float = typed_attribute()

26 

27 

@dataclass(kw_only=True)
class ProgressStream(ObjectStream[bytes]):
    """Byte-stream wrapper that reports transfer progress with ``rich``.

    Every chunk passing through :meth:`receive` or :meth:`send` advances a
    progress task by ``len(chunk)``. Receiving and sending use separate tasks,
    each created lazily on first use. A task's total comes from the wrapped
    stream's ``ProgressAttribute.total`` extra attribute, or ``None`` when the
    wrapped stream does not provide one.

    Attributes:
        stream: The wrapped object stream; all I/O is delegated to it.
        logging: When true, the live progress display is disabled. Instead, a
            text rendering of the progress table is sent to the module logger
            at most once every two seconds.
    """

    stream: ObjectStream
    logging: bool = False

    __prog: Progress | None = field(init=False, default=None)
    __recv: TaskID | None = field(init=False, default=None)
    __send: TaskID | None = field(init=False, default=None)
    __last: datetime = field(init=False, default_factory=datetime.now)

    def __post_init__(self):
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

        self.__prog = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TransferSpeedColumn(),
            DownloadColumn(),
            TextColumn("Elapsed:"),
            TimeElapsedColumn(),
            TextColumn("Remaining:"),
            TimeRemainingColumn(),
            # No live rendering when reporting via the logger, or on terminals
            # that declare themselves "dumb".
            disable=self.logging or os.environ.get("TERM") == "dumb",
        )

    def __del__(self):
        self._stop_progress()

    def _stop_progress(self) -> None:
        """Stop the live progress display if it is currently running."""
        # __prog falls back to its class-level default (None) when
        # construction failed before __post_init__ assigned it; __del__ still
        # runs on such half-built instances.
        prog = self.__prog
        if prog is not None and prog.live.is_started:
            prog.stop()

    def _add_task(self) -> TaskID:
        """Start the display and register a new "transfer" task for it."""
        self.__prog.start()
        return self.__prog.add_task(
            "transfer",
            total=self.stream.extra(ProgressAttribute.total, None),
        )

    def _advance(self, task: TaskID, amount: int) -> None:
        """Advance *task* by *amount*; when logging, emit a throttled snapshot."""
        self.__prog.advance(task, amount)
        if not self.logging:
            return

        now = datetime.now()
        # Throttle log output to at most one snapshot every two seconds.
        if now - self.__last > timedelta(seconds=2):
            self.__last = now
            console = Console(file=StringIO())
            console.print(self.__prog.get_renderable())
            logger.info(console.file.getvalue().rstrip())

    async def receive(self) -> bytes:
        """Receive the next chunk from the wrapped stream and record its size."""
        if self.__recv is None:
            self.__recv = self._add_task()

        item = await self.stream.receive()
        self._advance(self.__recv, len(item))
        return item

    async def send(self, item: bytes) -> None:
        """Send *item* to the wrapped stream and record its size.

        Progress is recorded on the send task only after the wrapped stream
        accepted the item, so a failed send is not counted as transferred.
        """
        if self.__send is None:
            self.__send = self._add_task()

        await self.stream.send(item)
        self._advance(self.__send, len(item))

    async def send_eof(self):
        await self.stream.send_eof()

    async def aclose(self):
        """Close the wrapped stream and stop the progress display."""
        try:
            await self.stream.aclose()
        finally:
            self._stop_progress()