Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py: 48%

99 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 15:50 +0200

1import asyncio 

2import os 

3import threading 

4from dataclasses import dataclass, field 

5from typing import Optional 

6 

7from jumpstarter_driver_opendal.driver import Opendal 

8 

9from jumpstarter_driver_tftp.server import TftpServer 

10 

11from jumpstarter.common.ipaddr import get_ip_address 

12from jumpstarter.driver import Driver, export 

13 

14 

class TftpError(Exception):
    """Root of the exception hierarchy used by the TFTP server driver."""

19 

20 

class ServerNotRunning(TftpError):
    """Error indicating that the TFTP server is not running."""

25 

26 

@dataclass(kw_only=True)
class Tftp(Driver):
    """TFTP Server driver for Jumpstarter

    This driver implements a TFTP read-only server. The server runs on a private
    asyncio event loop inside a background daemon thread, controlled through the
    exported ``start``/``stop`` methods.

    Attributes:
        root_dir (str): Root directory for the TFTP server. Defaults to "/var/lib/tftpboot"
        host (str): IP address to bind the server to. If empty, will use the default route interface
        port (int): Port number to listen on. Defaults to 69 (standard TFTP port)
    """

    root_dir: str = "/var/lib/tftpboot"
    host: str = field(default="")
    port: int = 69
    server: Optional["TftpServer"] = field(init=False, default=None)
    server_thread: Optional[threading.Thread] = field(init=False, default=None)
    # Set by stop() (or a failed start()) to ask the server thread to shut down.
    _shutdown_event: threading.Event = field(init=False, default_factory=threading.Event)
    # Set by the server thread once setup has finished, successfully or not.
    _loop_ready: threading.Event = field(init=False, default_factory=threading.Event)
    _loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None)
    # Exception raised while constructing the server, reported back to start().
    _startup_error: Optional[Exception] = field(init=False, default=None)

    def __post_init__(self):
        """Create the root directory, attach the filesystem storage child and resolve the host."""
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

        os.makedirs(self.root_dir, exist_ok=True)

        self.children["storage"] = Opendal(scheme="fs", kwargs={"root": self.root_dir})
        self.storage = self.children["storage"]

        if self.host == "":
            self.host = get_ip_address(logger=self.logger)

    @classmethod
    def client(cls) -> str:
        """Return the import path of the client class paired with this driver."""
        return "jumpstarter_driver_tftp.client.TftpServerClient"

    def _start_server(self):
        """Server thread entry point.

        Creates a dedicated event loop, builds the TftpServer and runs it until
        shutdown. ``_loop_ready`` is always set, even when server construction
        fails, so ``start()`` never waits for its timeout on a setup error. The
        failure is stored in ``_startup_error`` for ``start()`` to re-raise. The
        event loop is always closed before the thread exits.
        """
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        try:
            try:
                self.server = TftpServer(
                    host=self.host,
                    port=self.port,
                    operator=self.children["storage"]._operator,
                    logger=self.logger,
                )
            except Exception as e:
                # Record before the finally below wakes start(), so it sees the error.
                self._startup_error = e
                self.logger.error(f"Error creating TFTP server: {e}")
                return
            finally:
                self._loop_ready.set()

            try:
                self._loop.run_until_complete(self._run_server())
            except Exception as e:
                self.logger.error(f"Error running TFTP server: {e}")
        finally:
            try:
                self._loop.run_until_complete(self._loop.shutdown_asyncgens())
                self._loop.close()
            except Exception as e:
                self.logger.error(f"Error during event loop cleanup: {e}")
            self._loop = None
            self.logger.info("TFTP server thread completed")

    async def _run_server(self):
        """Run the server coroutine alongside the shutdown watcher.

        If either task fails, or this coroutine is cancelled, the remaining task
        is cancelled and awaited. This prevents the event loop from being closed
        while a task is still pending.
        """
        server_task = asyncio.create_task(self.server.start())
        shutdown_task = asyncio.create_task(self._wait_for_shutdown())
        tasks = (server_task, shutdown_task)
        try:
            await asyncio.gather(*tasks)
        except asyncio.CancelledError:
            self.logger.info("Server task cancelled")
            raise
        finally:
            pending = [task for task in tasks if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

    async def _wait_for_shutdown(self):
        """Poll the thread-safe shutdown event, then shut the server down."""
        # threading.Event cannot be awaited directly, so poll it from the loop.
        while not self._shutdown_event.is_set():
            await asyncio.sleep(0.1)
        self.logger.info("Shutdown event detected")
        if self.server is not None:
            await self.server.shutdown()

    @export
    def start(self):
        """Start the TFTP server.

        The server will start listening for incoming TFTP requests on the configured
        host and port. If the server is already running, a warning will be logged.

        Raises:
            TftpError: If the server fails to start or times out during initialization
        """
        if self.server_thread is not None and self.server_thread.is_alive():
            self.logger.warning("TFTP server is already running")
            return

        self._shutdown_event.clear()
        self._loop_ready.clear()
        self._startup_error = None

        self.server_thread = threading.Thread(target=self._start_server, daemon=True)
        self.server_thread.start()

        if not self._loop_ready.wait(timeout=5.0):
            self.logger.error("Timeout waiting for event loop to be ready")
            # Ask the still-initializing thread to exit instead of leaving an
            # unreferenced server running in the background.
            self._shutdown_event.set()
            self.server_thread = None
            raise TftpError("Failed to start TFTP server - event loop initialization timeout")

        if self._startup_error is not None:
            error = self._startup_error
            # The thread is only cleaning up its event loop at this point.
            self.server_thread.join(timeout=5.0)
            self.server_thread = None
            raise TftpError(f"Failed to start TFTP server: {error}") from error

        self.logger.info(f"TFTP server started on {self.host}:{self.port}")

    @export
    def stop(self):
        """Stop the TFTP server.

        Initiates a graceful shutdown of the server and waits for all active transfers
        to complete. If the server is not running, a warning will be logged.

        If the server thread does not exit within 10 seconds, an error is logged and
        the thread reference is kept, so the server is still reported as running.
        """
        if self.server_thread is None or not self.server_thread.is_alive():
            self.logger.warning("stop called - TFTP server is not running")
            return

        self.logger.info("Initiating TFTP server shutdown")
        self._shutdown_event.set()
        self.server_thread.join(timeout=10)
        if self.server_thread.is_alive():
            self.logger.error("Failed to stop TFTP server thread within timeout")
            return

        self.logger.info("TFTP server stopped successfully")
        self.server_thread = None

    @export
    def get_host(self) -> str:
        """Get the host address the server is bound to.

        Returns:
            str: The IP address or hostname
        """
        return self.host

    @export
    def get_port(self) -> int:
        """Get the port number the server is listening on.

        Returns:
            int: The port number
        """
        return self.port

    def close(self):
        """Stop the server if its thread is still alive, then close the base driver."""
        # Skip stop() for an already-finished thread to avoid a spurious warning.
        if self.server_thread is not None and self.server_thread.is_alive():
            self.stop()
        super().close()