Coverage for /home/fedora/jumpstarter/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py: 50%

107 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-05 20:29 +0000

1import asyncio 

2import os 

3import socket 

4import threading 

5from dataclasses import dataclass, field 

6from typing import Optional 

7 

8from jumpstarter_driver_opendal.driver import Opendal 

9 

10from jumpstarter_driver_tftp.server import TftpServer 

11 

12from jumpstarter.driver import Driver, export 

13 

14 

class TftpError(Exception):
    """Root of the exception hierarchy raised by the TFTP server driver."""

19 

20 

class ServerNotRunning(TftpError):
    """Raised for operations that require the TFTP server to be running when it is not."""

25 

26 

@dataclass(kw_only=True)
class Tftp(Driver):
    """TFTP Server driver for Jumpstarter

    This driver implements a TFTP read-only server.

    Attributes:
        root_dir (str): Root directory for the TFTP server. Defaults to "/var/lib/tftpboot"
        host (str): IP address to bind the server to. If empty, will use the default route interface
        port (int): Port number to listen on. Defaults to 69 (standard TFTP port)
    """

    root_dir: str = "/var/lib/tftpboot"
    host: str = field(default="")
    port: int = 69
    server: Optional["TftpServer"] = field(init=False, default=None)
    server_thread: Optional[threading.Thread] = field(init=False, default=None)
    # Set by stop() (or a failed start()) to ask the worker thread's loop to shut down.
    _shutdown_event: threading.Event = field(init=False, default_factory=threading.Event)
    # Set by the worker thread once the server object exists, or once creating it failed.
    _loop_ready: threading.Event = field(init=False, default_factory=threading.Event)
    _loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None)
    # Exception raised while constructing TftpServer in the worker thread, if any.
    _startup_error: Optional[Exception] = field(init=False, default=None)

    def __post_init__(self):
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

        os.makedirs(self.root_dir, exist_ok=True)

        # Files are served from root_dir through an fs-backed Opendal child driver.
        self.children["storage"] = Opendal(scheme="fs", kwargs={"root": self.root_dir})
        self.storage = self.children["storage"]

        if self.host == "":
            self.host = self.get_default_ip()

    def get_default_ip(self):
        """Get the IP address of the default route interface.

        Returns:
            str: The local address the kernel would use to reach 8.8.8.8, or
            "0.0.0.0" if that cannot be determined.
        """
        try:
            # connect() on a UDP socket only selects a route; no datagram is sent.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(("8.8.8.8", 80))
                return s.getsockname()[0]
        except OSError:
            self.logger.warning("Could not determine default IP address, falling back to 0.0.0.0")
            return "0.0.0.0"

    @classmethod
    def client(cls) -> str:
        return "jumpstarter_driver_tftp.client.TftpServerClient"

    def _start_server(self):
        """Worker-thread entry point: run the TFTP server on a private event loop.

        Always sets ``_loop_ready`` (on success or on construction failure) so that
        ``start()`` never waits for the full timeout because of an early error, and
        always closes the event loop before the thread exits.
        """
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)
        try:
            try:
                self.server = TftpServer(
                    host=self.host,
                    port=self.port,
                    operator=self.children["storage"]._operator,
                    logger=self.logger,
                )
            except Exception as e:
                # Recorded before _loop_ready is set, so start() is guaranteed to see it.
                self._startup_error = e
                raise
            finally:
                self._loop_ready.set()
            loop.run_until_complete(self._run_server())
        except Exception as e:
            self.logger.error(f"Error running TFTP server: {e}")
        finally:
            try:
                try:
                    loop.run_until_complete(loop.shutdown_asyncgens())
                finally:
                    # Close the loop even if async-generator shutdown fails.
                    loop.close()
            except Exception as e:
                self.logger.error(f"Error during event loop cleanup: {e}")
            self._loop = None
            self.logger.info("TFTP server thread completed")

    async def _run_server(self):
        """Run the server alongside a watcher that shuts it down on request."""
        server_task = asyncio.create_task(self.server.start())
        shutdown_task = asyncio.create_task(self._wait_for_shutdown())
        try:
            await asyncio.gather(server_task, shutdown_task)
        except asyncio.CancelledError:
            self.logger.info("Server task cancelled")
            raise
        finally:
            # gather() does not cancel the remaining task when one of them fails;
            # cancel and drain it so nothing is left pending when the loop closes.
            pending = [task for task in (server_task, shutdown_task) if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

    async def _wait_for_shutdown(self):
        """Poll the cross-thread shutdown event, then ask the server to shut down."""
        while not self._shutdown_event.is_set():
            await asyncio.sleep(0.1)
        self.logger.info("Shutdown event detected")
        if self.server is not None:
            await self.server.shutdown()

    @export
    def start(self):
        """Start the TFTP server.

        The server will start listening for incoming TFTP requests on the configured
        host and port. If the server is already running, a warning will be logged.

        Raises:
            TftpError: If the server fails to start or times out during initialization
        """
        if self.server_thread is not None and self.server_thread.is_alive():
            self.logger.warning("TFTP server is already running")
            return

        self._shutdown_event.clear()
        self._loop_ready.clear()
        self._startup_error = None

        self.server_thread = threading.Thread(target=self._start_server, daemon=True)
        self.server_thread.start()

        if not self._loop_ready.wait(timeout=5.0):
            self.logger.error("Timeout waiting for event loop to be ready")
            # Tell the abandoned thread to stop as soon as its loop comes up.
            self._shutdown_event.set()
            self.server_thread = None
            raise TftpError("Failed to start TFTP server - event loop initialization timeout")

        error = self._startup_error
        if error is not None:
            # The worker is only cleaning up its loop at this point; reap it.
            self.server_thread.join(timeout=5.0)
            self.server_thread = None
            raise TftpError(f"Failed to start TFTP server: {error}") from error

        # NOTE(review): readiness is signalled once TftpServer is constructed, before
        # server.start() has bound the socket; a bind failure is only logged by the thread.
        self.logger.info(f"TFTP server started on {self.host}:{self.port}")

    @export
    def stop(self):
        """Stop the TFTP server.

        Initiates a graceful shutdown of the server and waits for all active transfers
        to complete. If the server is not running, a warning will be logged.
        """
        if self.server_thread is None or not self.server_thread.is_alive():
            self.logger.warning("stop called - TFTP server is not running")
            return

        self.logger.info("Initiating TFTP server shutdown")
        self._shutdown_event.set()
        self.server_thread.join(timeout=10)
        if self.server_thread.is_alive():
            # Keep the reference so a later start() does not launch a second server
            # on the same port and a later stop()/close() can retry the join.
            self.logger.error("Failed to stop TFTP server thread within timeout")
        else:
            self.logger.info("TFTP server stopped successfully")
            self.server_thread = None

    @export
    def get_host(self) -> str:
        """Get the host address the server is bound to.

        Returns:
            str: The IP address or hostname
        """
        return self.host

    @export
    def get_port(self) -> int:
        """Get the port number the server is listening on.

        Returns:
            int: The port number
        """
        return self.port

    def close(self):
        """Stop the server if it was started, then close the base driver."""
        if self.server_thread is not None:
            self.stop()
        super().close()