Coverage for /home/fedora/jumpstarter/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py: 18%
392 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-05 20:29 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-05 20:29 +0000
1import asyncio
2import logging
3import pathlib
4from enum import IntEnum
5from typing import Optional, Set, Tuple
7from opendal import Operator
class Opcode(IntEnum):
    """TFTP packet opcodes.

    The opcode occupies the first two bytes of every packet and is encoded
    big-endian. Values 1-5 are the RFC 1350 packet types; OACK (6) is the
    option-acknowledgment packet used for option negotiation (RFC 2347).
    """

    RRQ = 0x01  # read request
    WRQ = 0x02  # write request
    DATA = 0x03  # numbered data block
    ACK = 0x04  # acknowledgment of a data block (or of an OACK, as block 0)
    ERROR = 0x05  # error packet: code + message
    OACK = 0x06  # option acknowledgment
class TftpErrorCode(IntEnum):
    """Error codes carried in TFTP ERROR packets.

    Codes 0-7 are the RFC 1350 error codes. Code 8 is added by RFC 2347 so
    that either side can terminate a transfer because option negotiation
    failed.
    """

    NOT_DEFINED = 0  # not defined; see the accompanying error message
    FILE_NOT_FOUND = 1
    ACCESS_VIOLATION = 2
    DISK_FULL = 3  # disk full or allocation exceeded
    ILLEGAL_OPERATION = 4
    UNKNOWN_TID = 5  # unknown transfer ID
    FILE_EXISTS = 6
    NO_SUCH_USER = 7
    OPTION_NEGOTIATION = 8  # RFC 2347: option negotiation failed / refused
class TftpServer:
    """
    TFTP Server that handles read requests (RRQ).

    Requests arrive on a UDP endpoint bound to ``(host, port)``. Every
    accepted read request is served by a transfer object that registers
    itself in ``active_transfers`` until it is cleaned up.
    """

    def __init__(
        self,
        host: str,
        port: int,
        operator: Operator,
        block_size: int = 512,
        timeout: float = 5.0,
        retries: int = 3,
        logger: logging.Logger | None = None,
    ):
        """
        Args:
            host: Local address the listening endpoint is bound to.
            port: Local UDP port the listening endpoint is bound to.
            operator: opendal ``Operator`` used to stat and open served files.
            block_size: Data block size used when the client negotiates none.
            timeout: Seconds to wait for an ACK when the client negotiates none.
            retries: Send attempts per packet before a transfer gives up.
            logger: Optional parent logger. If given, a child named after this
                class is used; otherwise a logger named after this class.
        """
        self.host = host
        self.port = port
        self.operator = operator
        self.block_size = block_size
        self.timeout = timeout
        self.retries = retries
        self.active_transfers: Set["TftpTransfer"] = set()
        self.shutdown_event = asyncio.Event()
        self.transport: Optional[asyncio.DatagramTransport] = None
        self.protocol: Optional["TftpServerProtocol"] = None

        if logger is not None:
            self.logger = logger.getChild(self.__class__.__name__)
        else:
            self.logger = logging.getLogger(self.__class__.__name__)

        # Set by start() once the listening socket has been bound.
        self.ready_event = asyncio.Event()

    @property
    def address(self) -> Optional[Tuple[str, int]]:
        """Get the server's bound address and port (None before start() binds)."""
        if self.transport:
            return self.transport.get_extra_info("socket").getsockname()
        return None

    async def start(self):
        """Bind the listening endpoint and serve until shutdown() is called.

        ``ready_event`` is set only after the socket is bound, so a caller that
        waits on it can immediately read ``address``. When serving stops
        (shutdown or cancellation), ``_cleanup()`` releases all resources.
        Cancellation is absorbed rather than re-raised.
        """
        self.logger.info(f"Starting TFTP server on {self.host}:{self.port}")
        loop = asyncio.get_running_loop()

        self.transport, self.protocol = await loop.create_datagram_endpoint(
            lambda: TftpServerProtocol(self), local_addr=(self.host, self.port)
        )
        # Signal readiness only after the bind succeeded. Setting the event
        # earlier allowed waiters to observe ``address`` as None.
        self.ready_event.set()

        try:
            await self.shutdown_event.wait()
        except asyncio.CancelledError:
            # Treat cancellation like a shutdown request; cleanup runs below.
            pass
        finally:
            self.logger.info("TFTP server shutting down")
            await self._cleanup()

    async def _cleanup(self):
        """Clean up every active transfer, then close the listening transport."""
        self.logger.info("Cleaning up TFTP server resources")

        # Cancel all active transfers. Iterate over a copy because each
        # transfer's cleanup() unregisters it from active_transfers.
        cleanup_tasks = [transfer.cleanup() for transfer in self.active_transfers.copy()]
        if cleanup_tasks:
            await asyncio.gather(*cleanup_tasks, return_exceptions=True)

        # Close the main transport
        if self.transport:
            self.transport.close()
            self.transport = None

        self.logger.info("TFTP server cleanup completed")

    async def shutdown(self):
        """Ask start() to stop serving; returns without waiting for cleanup."""
        self.logger.info("Shutdown signal received for TFTP server")
        self.shutdown_event.set()

    def register_transfer(self, transfer: "TftpTransfer"):
        """Track *transfer* so it is cleaned up on server shutdown."""
        self.active_transfers.add(transfer)
        self.logger.debug(f"Registered transfer: {transfer}")

    def unregister_transfer(self, transfer: "TftpTransfer"):
        """Stop tracking *transfer*; a no-op if it is not registered."""
        self.active_transfers.discard(transfer)
        self.logger.debug(f"Unregistered transfer: {transfer}")
class TftpServerProtocol(asyncio.DatagramProtocol):
    """
    Protocol for handling incoming TFTP requests.

    Runs on the server's main endpoint. Each read request (RRQ) is handled in
    a background task that validates it, negotiates options and starts a
    ``TftpReadTransfer``. Other request types are answered with an ERROR;
    ERROR packets themselves are never answered.
    """

    # Valid blksize range from RFC 2348.
    _MIN_BLKSIZE = 8
    _MAX_BLKSIZE = 65464

    def __init__(self, server: TftpServer):
        self.server = server
        self.transport: Optional[asyncio.DatagramTransport] = None
        self.logger = server.logger.getChild(self.__class__.__name__)
        # The event loop keeps only weak references to tasks, so a
        # fire-and-forget task can be garbage-collected mid-execution.
        # Hold strong references here until each task finishes.
        self._background_tasks: Set[asyncio.Task] = set()

    def connection_made(self, transport: asyncio.DatagramTransport):
        self.transport = transport
        self.logger.debug("Server protocol connection established")

    def connection_lost(self, exc: Optional[Exception]):
        self.logger.info("TFTP server protocol connection lost")

    def _spawn(self, coro) -> asyncio.Task:
        """Run *coro* as a task and keep a reference to it until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def datagram_received(self, data: bytes, addr: Tuple[str, int]):
        """Dispatch one packet received on the main server port."""
        self.logger.debug(f"Received datagram from {addr}")
        # Too short to hold an opcode plus any payload.
        if len(data) < 4:
            self.logger.warning(f"Received malformed packet from {addr}")
            self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Malformed packet")
            return

        if self.server.shutdown_event.is_set():
            self.logger.debug(f"Ignoring packet from {addr} - server is shutting down")
            return

        try:
            opcode = Opcode(int.from_bytes(data[0:2], "big"))
        except ValueError:
            self.logger.error(f"Unknown opcode from {addr}")
            self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode")
            return

        self.logger.debug(f"Received opcode {opcode.name} from {addr}")

        if opcode == Opcode.RRQ:
            self._spawn(self._handle_read_request(data, addr))
        elif opcode == Opcode.ERROR:
            # RFC 1350: ERROR packets are not acknowledged. Replying with an
            # ERROR could bounce errors back and forth with the peer.
            self.logger.warning(f"Ignoring ERROR packet from {addr}")
        else:
            self.logger.warning(f"Unsupported opcode {opcode} from {addr}")
            self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unsupported operation")

    async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]):
        """Validate an RRQ and start the transfer.

        Any unexpected error is logged and reported to the client as
        NOT_DEFINED, with the exception text as the message.
        """
        try:
            filename, mode, options = self._parse_request(data)
            self.logger.info(f"RRQ from {addr}: '{filename}' in mode '{mode}' with options {options}")

            if not self._validate_mode(mode, addr):
                return

            resolved_path = await self._resolve_and_validate_path(filename, addr)
            if not resolved_path:
                return

            negotiated_options, blksize, timeout = self._negotiate_options(options)
            self.logger.info(f"Negotiated options: {negotiated_options}")
            await self._start_transfer(resolved_path, addr, blksize, timeout, negotiated_options)
        except Exception as e:
            self.logger.error(f"Error handling RRQ from {addr}: {e}")
            self._send_error(addr, TftpErrorCode.NOT_DEFINED, str(e))

    def _send_oack(self, addr: Tuple[str, int], options: dict):
        """Send Option Acknowledgment (OACK) packet.

        Not called by the visible code; TftpReadTransfer builds and sends its
        own OACK from the transfer socket.
        """
        oack_data = Opcode.OACK.to_bytes(2, "big")
        for opt_name, opt_value in options.items():
            oack_data += f"{opt_name}\0{str(opt_value)}\0".encode("utf-8")

        if self.transport:
            self.transport.sendto(oack_data, addr)
            self.logger.debug(f"Sent OACK to {addr} with options {options}")

    def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str):
        """Send an ERROR packet (opcode, code, NUL-terminated message) to *addr*."""
        error_packet = (
            Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00"
        )
        if self.transport:
            self.transport.sendto(error_packet, addr)
            self.logger.debug(f"Sent ERROR {error_code.name} to {addr}: {message}")

    def _parse_request(self, data: bytes) -> Tuple[str, str, dict]:
        """Split an RRQ into (filename, lower-cased mode, options).

        Raises:
            ValueError: if the packet lacks a mode field, or the filename is
                too long, not valid UTF-8, or contains forbidden characters.
        """
        parts = data[2:].split(b"\x00")
        if len(parts) < 2:
            raise ValueError("Invalid RRQ format")

        filename = parts[0].decode("utf-8")
        if len(filename) > 255:  # RFC 1350 doesn't specify a limit
            raise ValueError("Filename too long")
        # Rejecting "/" and "\\" keeps requests to flat names in the operator root.
        # NUL cannot appear here: the packet was split on it above.
        if not all(c.isprintable() and c not in '<>:"/\\|?*' for c in filename):
            raise ValueError("Invalid characters in filename")
        mode = parts[1].decode("utf-8").lower()
        options = self._parse_options(parts[2:])

        return filename, mode, options

    def _parse_options(self, option_parts: list) -> dict:
        """Pair up option name/value fields; names are lower-cased.

        A trailing unpaired field (e.g. the empty piece after the final NUL)
        is ignored. Parsing stops at the first field that is not valid UTF-8.
        """
        options = {}
        i = 0
        while i < len(option_parts) - 1:
            try:
                opt_name = option_parts[i].decode("utf-8").lower()
                opt_value = option_parts[i + 1].decode("utf-8")
                options[opt_name] = opt_value
                i += 2
            except Exception:
                break
        return options

    def _validate_mode(self, mode: str, addr: Tuple[str, int]) -> bool:
        """Return True for a supported transfer mode; otherwise send ERROR and return False."""
        if mode not in ("netascii", "octet"):
            self.logger.warning(f"Unsupported transfer mode '{mode}' from {addr}")
            self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unsupported transfer mode")
            return False
        return True

    async def _resolve_and_validate_path(self, filename: str, addr: Tuple[str, int]) -> Optional[str]:
        """Return *filename* if it names a regular file; otherwise send ERROR and return None."""
        try:
            stat = await self.server.operator.stat(filename)
        except FileNotFoundError:
            self.logger.error(f"File not found: {filename}")
            self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "File not found")
            return None

        if not stat.mode.is_file():
            self.logger.error(f"Not a file: {filename}")
            self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "Not a file")
            return None

        return filename

    def _negotiate_block_size(self, requested_blksize: Optional[str]) -> int:
        """Return the requested blksize if valid per RFC 2348, else the server default."""
        if requested_blksize is None:
            return self.server.block_size

        try:
            blksize = int(requested_blksize)
        except ValueError:
            self.logger.warning(
                f"Invalid block size value '{requested_blksize}', using default: {self.server.block_size}"
            )
            return self.server.block_size

        # RFC 2348 allows 8..65464. The previous lower bound of 512 answered
        # small requests with a larger value, which the RFC forbids.
        if self._MIN_BLKSIZE <= blksize <= self._MAX_BLKSIZE:
            return blksize
        self.logger.warning(
            f"Requested block size {blksize} out of range "
            f"({self._MIN_BLKSIZE}-{self._MAX_BLKSIZE}), using default: {self.server.block_size}"
        )
        return self.server.block_size

    def _negotiate_timeout(self, requested_timeout: Optional[str]) -> float:
        """Return the requested timeout (integer seconds, 1..255) or the server default."""
        if requested_timeout is None:
            return self.server.timeout

        try:
            timeout = int(requested_timeout)
            if 1 <= timeout <= 255:
                return float(timeout)
            else:
                self.logger.warning(
                    f"Timeout value {timeout} out of range (1-255), using default: {self.server.timeout}"
                )
                return self.server.timeout
        except ValueError:
            self.logger.warning(f"Invalid timeout value '{requested_timeout}', using default: {self.server.timeout}")
            return self.server.timeout

    def _negotiate_options(self, options: dict) -> Tuple[dict, int, float]:
        """Return (options to acknowledge, effective blksize, effective timeout).

        Only options the client requested appear in the returned dict.
        """
        negotiated = {}
        blksize = self.server.block_size
        timeout = self.server.timeout

        if "blksize" in options:
            requested = options["blksize"]
            blksize = self._negotiate_block_size(requested)
            negotiated["blksize"] = blksize

        if "timeout" in options:
            requested = options["timeout"]
            timeout = self._negotiate_timeout(requested)
            negotiated["timeout"] = int(timeout)

        return negotiated, blksize, timeout

    async def _start_transfer(
        self, filepath: str, addr: Tuple[str, int], blksize: int, timeout: float, negotiated_options: dict
    ):
        """Create, register and launch a TftpReadTransfer for *addr*."""
        transfer = TftpReadTransfer(
            server=self.server,
            filepath=filepath,
            client_addr=addr,
            block_size=blksize,
            timeout=timeout,
            retries=self.server.retries,
            negotiated_options=negotiated_options,
        )
        self.server.register_transfer(transfer)
        # Keep a strong reference so the running transfer is not garbage-collected.
        self._spawn(transfer.start())
def is_subpath(path: pathlib.Path, root: pathlib.Path) -> bool:
    """Return True if *path* is *root* itself or lies underneath it.

    The check is lexical. Both arguments are first normalised with
    ``os.path.normpath``, so ``..`` segments cannot escape *root*: for example,
    ``/srv/tftp/../etc`` is *not* under ``/srv/tftp``. Symlinks are not
    resolved and the filesystem is not accessed.

    Args:
        path: Candidate path (``str`` or ``pathlib.Path``).
        root: Directory that *path* must stay within (``str`` or ``pathlib.Path``).
    """
    import os  # local import keeps this helper self-contained

    normalized_path = pathlib.Path(os.path.normpath(path))
    normalized_root = pathlib.Path(os.path.normpath(root))
    try:
        relative = normalized_path.relative_to(normalized_root)
    except ValueError:
        return False
    # For relative inputs, normpath can leave leading ".." segments
    # (e.g. "../x" vs "."). They point outside root.
    return ".." not in relative.parts
class TftpTransfer:
    """
    Base class for TFTP transfers.

    Holds the per-transfer settings and the transfer's own datagram
    transport. Subclasses implement ``start()``.
    """

    def __init__(
        self,
        server: TftpServer,
        filepath: str | pathlib.Path,
        client_addr: Tuple[str, int],
        block_size: int,
        timeout: float,
        retries: int,
    ):
        """
        Args:
            server: Owning server; provides the parent logger and is told when
                this transfer is cleaned up.
            filepath: Path of the file within the server's opendal operator.
                The server protocol passes the request filename as a ``str``.
            client_addr: Address the client sent its request from.
            block_size: Data block size in bytes for this transfer.
            timeout: Seconds to wait for each ACK.
            retries: Send attempts per packet.
        """
        self.server = server
        self.filepath = filepath
        self.client_addr = client_addr
        self.block_size = block_size
        self.timeout = timeout
        self.retries = retries
        self.transport: Optional[asyncio.DatagramTransport] = None
        self.protocol: Optional["TftpTransferProtocol"] = None
        # Not referenced by the visible transfer code; kept for compatibility.
        self.cleanup_task: Optional[asyncio.Task] = None
        self.logger = server.logger.getChild(self.__class__.__name__)

    async def start(self):
        """Start the transfer."""
        raise NotImplementedError

    async def cleanup(self):
        """Clean up transfer resources.

        Closes the transport (if still open) and unregisters from the server.
        Safe to call more than once: the transport reference is cleared after
        closing, and unregistering an unknown transfer is a no-op.
        """
        self.logger.info(f"Cleaning up transfer for {self.client_addr}")
        if self.transport:
            self.transport.close()
            self.transport = None
        self.server.unregister_transfer(self)
class TftpReadTransfer(TftpTransfer):
    """
    Serve a single read request (RRQ) from a dedicated datagram endpoint.

    If negotiated options require it, an OACK is sent first and must be
    acknowledged with ACK 0. The file is then sent as numbered DATA blocks;
    each block is retransmitted on timeout until its ACK arrives or the
    retries run out.
    """

    # Block size a client assumes when the server sends no OACK.
    _DEFAULT_BLKSIZE = 512

    def __init__(
        self,
        server: TftpServer,
        filepath: str,
        client_addr: Tuple[str, int],
        block_size: int,
        timeout: float,
        retries: int,
        negotiated_options: Optional[dict] = None,
    ):
        super().__init__(
            server=server,
            filepath=filepath,
            client_addr=client_addr,
            block_size=block_size,
            timeout=timeout,
            retries=retries,
        )
        # Block currently being sent/awaited (0 while an OACK is pending).
        self.block_num = 0
        # Set by handle_ack() when the awaited ACK arrives.
        self.ack_received = asyncio.Event()
        # Block number of the most recently accepted ACK.
        self.last_ack = 0
        # True once no OACK is outstanding (acknowledged, or none needed).
        self.oack_confirmed = False
        self.negotiated_options = negotiated_options
        # Last packet handed to _send_with_retries.
        self.current_packet: Optional[bytes] = None

    async def start(self):
        """Run the transfer to completion; resources are always released afterwards."""
        self.logger.info(f"Starting read transfer of '{self.filepath}' to {self.client_addr}")

        try:
            # Initialisation is inside the try so that a failed OACK exchange,
            # or an error while binding, still closes the transport and
            # unregisters the transfer.
            if not await self._initialize_transfer():
                return
            await self._perform_transfer()
        except Exception as e:
            self.logger.error(f"Error during read transfer: {e}")
        finally:
            await self.cleanup()

    async def _initialize_transfer(self) -> bool:
        """Open the transfer socket and, if needed, complete the OACK exchange.

        Returns False if an OACK was sent and never acknowledged.
        """
        loop = asyncio.get_running_loop()

        self.transport, self.protocol = await loop.create_datagram_endpoint(
            lambda: TftpTransferProtocol(self), local_addr=("0.0.0.0", 0), remote_addr=self.client_addr
        )
        local_addr = self.transport.get_extra_info("sockname")
        self.logger.debug(f"Transfer bound to local {local_addr}")

        # Only send OACK if we have non-default options to negotiate
        if self._needs_oack():
            oack_packet = self._create_oack_packet()
            if not await self._send_with_retries(oack_packet, is_oack=True):
                self.logger.error("Failed to get acknowledgment for OACK")
                return False
        else:
            # Nothing to acknowledge. Data can flow right away, and a stray
            # ACK 0 must not be taken as an OACK confirmation.
            self.oack_confirmed = True

        self.block_num = 1
        return True

    def _needs_oack(self) -> bool:
        """Return True if a negotiated option differs from what the client assumes without an OACK."""
        if not self.negotiated_options:
            return False
        # Use .get(): the client may have requested only one of the options.
        # Indexing a missing key used to raise KeyError here.
        blksize = self.negotiated_options.get("blksize", self._DEFAULT_BLKSIZE)
        timeout = self.negotiated_options.get("timeout", self.server.timeout)
        return blksize != self._DEFAULT_BLKSIZE or timeout != self.server.timeout

    async def _perform_transfer(self):
        """Send the file block by block until EOF, a send failure, or server shutdown."""
        async with await self.server.operator.open(self.filepath, "rb") as f:
            while True:
                if self.server.shutdown_event.is_set():
                    self.logger.info(f"Server shutdown detected, stopping transfer to {self.client_addr}")
                    break

                # Read a full block or until EOF; the loop tolerates short reads.
                data = bytearray()
                while len(data) < self.block_size:
                    chunk = await f.read(size=self.block_size - len(data))
                    if not chunk:  # EOF reached
                        break
                    data.extend(chunk)

                # send the data (converted to bytes)
                if not await self._handle_data_block(bytes(data)):
                    break

    async def _handle_data_block(self, data: bytes) -> bool:
        """
        Handle sending a block of data to the client.
        Returns False if transfer should stop, True if it should continue.
        """
        if not data and self.block_num == 1:
            # Empty file case
            packet = self._create_data_packet(b"")
            await self._send_with_retries(packet)
            return False
        elif data:
            packet = self._create_data_packet(data)
            success = await self._send_with_retries(packet)
            if not success:
                self.logger.error(f"Failed to send block {self.block_num} to {self.client_addr}")
                return False

            self.logger.debug(f"Block {self.block_num} sent successfully")
            self.block_num += 1

            # wrap block number around if it exceeds 16 bits
            self.block_num %= 65536

            # A short block marks the end of the file.
            if len(data) < self.block_size:
                self.logger.info(f"Final block {self.block_num - 1} sent")
                return False
            return True
        else:
            # EOF reached right after a full block: an empty DATA packet
            # tells the client the file is complete.
            packet = self._create_data_packet(b"")
            success = await self._send_with_retries(packet)
            if not success:
                self.logger.error(f"Failed to send final block {self.block_num}")
            else:
                self.logger.info(f"Transfer complete, final block {self.block_num}")
            return False

    def _create_oack_packet(self) -> bytes:
        """Build an OACK listing every negotiated option as NUL-terminated name/value pairs."""
        packet = Opcode.OACK.to_bytes(2, "big")
        for opt_name, opt_value in self.negotiated_options.items():
            packet += f"{opt_name}\0{str(opt_value)}\0".encode("utf-8")
        return packet

    def _create_data_packet(self, data: bytes) -> bytes:
        """Build a DATA packet for the current block number."""
        return Opcode.DATA.to_bytes(2, "big") + self.block_num.to_bytes(2, "big") + data

    def _send_packet(self, packet: bytes):
        """Send *packet* to the client over the connected transfer transport."""
        self.transport.sendto(packet)
        if packet[0:2] == Opcode.DATA.to_bytes(2, "big"):
            block = int.from_bytes(packet[2:4], "big")
            data_length = len(packet) - 4
            self.logger.debug(f"Sent DATA block {block} ({data_length} bytes) to {self.client_addr}")
        elif packet[0:2] == Opcode.OACK.to_bytes(2, "big"):
            self.logger.debug(f"Sent OACK to {self.client_addr}")

    async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool:
        """Send *packet* until the matching ACK arrives; False after ``retries`` attempts.

        An OACK is acknowledged with block 0; a DATA packet with its own block number.
        """
        self.current_packet = packet
        expected_block = 0 if is_oack else self.block_num

        for attempt in range(1, self.retries + 1):
            try:
                self._send_packet(packet)
                self.logger.debug(
                    f"Sent {'OACK' if is_oack else 'DATA'} block {expected_block}, waiting for ACK (Attempt {attempt})"
                )
                self.ack_received.clear()
                await asyncio.wait_for(self.ack_received.wait(), timeout=self.timeout)

                if self.last_ack == expected_block:
                    self.logger.debug(f"ACK received for block {expected_block}")
                    return True
                else:
                    self.logger.warning(f"Received wrong ACK: expected {expected_block}, got {self.last_ack}")

            except asyncio.TimeoutError:
                self.logger.warning(f"Timeout waiting for ACK of block {expected_block} (Attempt {attempt})")

        return False

    def handle_ack(self, block_num: int):
        """Process an ACK from the client (called by TftpTransferProtocol)."""
        self.logger.debug(f"Received ACK for block {block_num} from {self.client_addr}")

        # special handling for OACK acknowledgment
        if not self.oack_confirmed and self.negotiated_options and block_num == 0:
            self.oack_confirmed = True
            self.last_ack = block_num
            self.ack_received.set()
            return

        if block_num == self.block_num:
            self.last_ack = block_num
            self.ack_received.set()
        elif block_num == (self.block_num - 1) % 65536:
            # Duplicate ACK for the previous block. Do NOT retransmit here:
            # answering duplicate ACKs with duplicate DATA is the "Sorcerer's
            # Apprentice" bug (RFC 1123, 4.2.3.1), which doubles every later
            # packet. Lost packets are resent by the timeout in
            # _send_with_retries instead.
            self.logger.debug(f"Duplicate ACK for block {block_num} ignored")
        else:
            self.logger.warning(f"Out of sequence ACK: expected {self.block_num}, got {block_num}")
class TftpTransferProtocol(asyncio.DatagramProtocol):
    """
    Protocol for handling ACKs during a TFTP transfer.

    Runs on the transfer's own socket, which is connected to the client.
    ACKs are forwarded to ``TftpReadTransfer.handle_ack``. ERROR packets from
    the client are logged but never answered.
    """

    def __init__(self, transfer: TftpReadTransfer):
        self.transfer = transfer
        self.logger = transfer.logger.getChild(self.__class__.__name__)

    def connection_made(self, transport: asyncio.DatagramTransport):
        # Hand the transport to the transfer so it can send DATA/OACK packets.
        self.transfer.transport = transport
        local_addr = transport.get_extra_info("sockname")
        self.logger.debug(f"Transfer protocol connection established on {local_addr} for {self.transfer.client_addr}")

    def datagram_received(self, data: bytes, addr: Tuple[str, int]):
        """Handle one packet from the client on the transfer socket."""
        self.logger.debug(f"Received datagram from {addr}")
        if addr != self.transfer.client_addr:
            self.logger.warning(f"Ignoring packet from unknown source {addr}")
            return

        # Too short to hold an opcode plus a block number / error code.
        if len(data) < 4:
            self.logger.warning(f"Received malformed packet from {addr}")
            return

        try:
            opcode = Opcode(int.from_bytes(data[0:2], "big"))
        except ValueError:
            self.logger.error(f"Unknown opcode from {addr}")
            self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode")
            return

        if opcode == Opcode.ACK:
            block_num = int.from_bytes(data[2:4], "big")
            self.logger.debug(f"Received ACK for block {block_num} from {addr}")
            self.transfer.handle_ack(block_num)
        elif opcode == Opcode.ERROR:
            # RFC 1350: an ERROR packet is not acknowledged, so do not answer
            # it with another ERROR. The sender stops once its retries are
            # exhausted.
            error_code = int.from_bytes(data[2:4], "big")
            message = data[4:].split(b"\x00", 1)[0].decode("utf-8", errors="replace")
            self.logger.warning(f"Client {addr} sent ERROR {error_code}: {message}")
        else:
            self.logger.warning(f"Unexpected opcode {opcode} from {addr}")
            self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unexpected opcode")

    def error_received(self, exc):
        self.logger.error(f"Error received: {exc}")

    def connection_lost(self, exc):
        self.logger.debug(f"Connection closed for transfer to {self.transfer.client_addr}")

    def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str):
        """Send an ERROR packet over the transfer's connected transport.

        *addr* is only used for logging; the connected transport delivers to
        the transfer's client address.
        """
        error_packet = (
            Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00"
        )
        if self.transfer.transport:
            self.transfer.transport.sendto(error_packet)
            self.logger.debug(f"Sent ERROR {error_code.name} to {addr}: {message}")