Coverage for /home/fedora/jumpstarter/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py: 50%
107 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-05 20:29 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-05 20:29 +0000
1import asyncio
2import os
3import socket
4import threading
5from dataclasses import dataclass, field
6from typing import Optional
8from jumpstarter_driver_opendal.driver import Opendal
10from jumpstarter_driver_tftp.server import TftpServer
12from jumpstarter.driver import Driver, export
class TftpError(Exception):
    """Base exception for TFTP server errors.

    Every error raised by this driver derives from this class, so callers can
    catch ``TftpError`` to handle all of them.
    """
class ServerNotRunning(TftpError):
    """Raised to signal that the TFTP server is not running."""
@dataclass(kw_only=True)
class Tftp(Driver):
    """TFTP Server driver for Jumpstarter

    This driver implements a TFTP read-only server. The server runs on its own
    asyncio event loop inside a daemon thread that is created by ``start()``
    and asked to exit by ``stop()``.

    Attributes:
        root_dir (str): Root directory for the TFTP server. Defaults to "/var/lib/tftpboot"
        host (str): IP address to bind the server to. If empty, will use the default route interface
        port (int): Port number to listen on. Defaults to 69 (standard TFTP port)
    """

    root_dir: str = "/var/lib/tftpboot"
    host: str = field(default="")
    port: int = 69
    # Runtime state below is never accepted by the constructor (init=False).
    server: Optional["TftpServer"] = field(init=False, default=None)
    server_thread: Optional[threading.Thread] = field(init=False, default=None)
    # Set to ask the server thread to shut down; cleared again by start().
    _shutdown_event: threading.Event = field(init=False, default_factory=threading.Event)
    # Set by the server thread once its event loop and server object exist.
    _loop_ready: threading.Event = field(init=False, default_factory=threading.Event)
    # Event loop created by the server thread; reset to None when the thread ends.
    _loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None)

    def __post_init__(self):
        """Create the root directory, attach the storage child and resolve the host."""
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

        os.makedirs(self.root_dir, exist_ok=True)

        # Filesystem-backed storage rooted at root_dir, exposed as a child driver.
        self.children["storage"] = Opendal(scheme="fs", kwargs={"root": self.root_dir})
        self.storage = self.children["storage"]

        if self.host == "":
            self.host = self.get_default_ip()

    def get_default_ip(self):
        """Get the IP address of the default route interface

        Connects a UDP socket towards 8.8.8.8 and reports the local address
        that the socket ends up with.

        Returns:
            str: The local IP address, or "0.0.0.0" if it cannot be determined
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(("8.8.8.8", 80))
                return s.getsockname()[0]
        except OSError:
            # Socket creation/connect failures are all OSError subclasses.
            self.logger.warning("Could not determine default IP address, falling back to 0.0.0.0")
            return "0.0.0.0"

    @classmethod
    def client(cls) -> str:
        """Return the import path of the client class that pairs with this driver."""
        return "jumpstarter_driver_tftp.client.TftpServerClient"

    def _start_server(self):
        """Thread target: run the TFTP server on a fresh event loop until shutdown.

        Errors are logged rather than raised; the loop is always closed and
        ``self._loop`` reset before the thread ends.
        """
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        try:
            # Construct the server inside the try so that a constructor failure
            # is logged and the event loop is still closed by the finally block.
            self.server = TftpServer(
                host=self.host,
                port=self.port,
                operator=self.children["storage"]._operator,
                logger=self.logger,
            )
            # Unblocks start(), which waits on this event.
            self._loop_ready.set()
            self._loop.run_until_complete(self._run_server())
        except Exception as e:
            self.logger.error(f"Error running TFTP server: {e}")
        finally:
            try:
                self._loop.run_until_complete(self._loop.shutdown_asyncgens())
                self._loop.close()
            except Exception as e:
                self.logger.error(f"Error during event loop cleanup: {e}")
            self._loop = None
            self.logger.info("TFTP server thread completed")

    async def _run_server(self):
        """Run the server and the shutdown watcher concurrently until both finish.

        Raises:
            asyncio.CancelledError: Re-raised after logging if this coroutine is cancelled
        """
        server_task = asyncio.create_task(self.server.start())
        shutdown_task = asyncio.create_task(self._wait_for_shutdown())
        try:
            await asyncio.gather(server_task, shutdown_task)
        except asyncio.CancelledError:
            self.logger.info("Server task cancelled")
            raise
        finally:
            # gather() does not cancel the sibling when one task fails; do it
            # here so the loop is never closed with a task still pending.
            pending = [task for task in (server_task, shutdown_task) if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

    async def _wait_for_shutdown(self):
        """Poll the shutdown event every 0.1s, then shut the server down."""
        while not self._shutdown_event.is_set():
            await asyncio.sleep(0.1)
        self.logger.info("Shutdown event detected")
        if self.server is not None:
            await self.server.shutdown()

    @export
    def start(self):
        """Start the TFTP server.

        Spawns the server thread and waits up to 5 seconds for it to report
        that its event loop is ready. If the server is already running, a
        warning will be logged and nothing else happens.

        Raises:
            TftpError: If the server thread does not become ready within the timeout
        """
        if self.server_thread is not None and self.server_thread.is_alive():
            self.logger.warning("TFTP server is already running")
            return

        self._shutdown_event.clear()
        self._loop_ready.clear()

        self.server_thread = threading.Thread(target=self._start_server, daemon=True)
        self.server_thread.start()

        if not self._loop_ready.wait(timeout=5.0):
            self.logger.error("Timeout waiting for event loop to be ready")
            # The handle is about to be dropped; signal shutdown so a thread
            # that comes up late exits on its own instead of running unmanaged.
            self._shutdown_event.set()
            self.server_thread = None
            raise TftpError("Failed to start TFTP server - event loop initialization timeout")

        self.logger.info(f"TFTP server started on {self.host}:{self.port}")

    @export
    def stop(self):
        """Stop the TFTP server.

        Signals the server thread to shut down and waits up to 10 seconds for
        it to exit. If the server is not running, a warning will be logged.
        If the thread does not exit in time, an error is logged and the thread
        handle is kept so that a later ``stop()`` can retry.
        """
        if self.server_thread is None or not self.server_thread.is_alive():
            self.logger.warning("stop called - TFTP server is not running")
            return

        self.logger.info("Initiating TFTP server shutdown")
        self._shutdown_event.set()
        self.server_thread.join(timeout=10)
        if self.server_thread.is_alive():
            # Keep the handle while the thread is alive so start() refuses to
            # launch a second server next to the one that failed to stop.
            self.logger.error("Failed to stop TFTP server thread within timeout")
            return

        self.logger.info("TFTP server stopped successfully")
        self.server_thread = None

    @export
    def get_host(self) -> str:
        """Get the host address the server is bound to.

        Returns:
            str: The IP address or hostname
        """
        return self.host

    @export
    def get_port(self) -> int:
        """Get the port number the server is listening on.

        Returns:
            int: The port number
        """
        return self.port

    def close(self):
        """Stop the server if a thread handle exists, then close the base driver."""
        if self.server_thread is not None:
            self.stop()
        super().close()