Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py: 18%

392 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 15:50 +0200

1import asyncio 

2import logging 

3import pathlib 

4from enum import IntEnum 

5from typing import Optional, Set, Tuple 

6 

7from opendal import AsyncOperator 

8 

9 

class Opcode(IntEnum):
    """TFTP packet opcodes.

    Encoded as the first two bytes (big-endian) of every packet. Values 1-5
    are the RFC 1350 opcodes; OACK (6) is the RFC 2347 option acknowledgment.
    """

    RRQ = 1  # read request
    WRQ = 2  # write request
    DATA = 3  # data block
    ACK = 4  # acknowledgment
    ERROR = 5  # error report
    OACK = 6  # option acknowledgment

17 

18 

class TftpErrorCode(IntEnum):
    """Error codes carried in the 2-byte code field of a TFTP ERROR packet."""

    NOT_DEFINED = 0  # not defined; see the accompanying message
    FILE_NOT_FOUND = 1
    ACCESS_VIOLATION = 2
    DISK_FULL = 3  # disk full or allocation exceeded
    ILLEGAL_OPERATION = 4  # illegal TFTP operation
    UNKNOWN_TID = 5  # unknown transfer ID
    FILE_EXISTS = 6  # file already exists
    NO_SUCH_USER = 7

28 

29 

30class TftpServer: 

31 """ 

32 TFTP Server that handles read requests (RRQ). 

33 """ 

34 

35 def __init__( 

36 self, 

37 host: str, 

38 port: int, 

39 operator: AsyncOperator, 

40 block_size: int = 512, 

41 timeout: float = 5.0, 

42 retries: int = 3, 

43 logger: logging.Logger | None = None, 

44 ): 

45 self.host = host 

46 self.port = port 

47 self.operator = operator 

48 self.block_size = block_size 

49 self.timeout = timeout 

50 self.retries = retries 

51 self.active_transfers: Set["TftpTransfer"] = set() 

52 self.shutdown_event = asyncio.Event() 

53 self.transport: Optional[asyncio.DatagramTransport] = None 

54 self.protocol: Optional["TftpServerProtocol"] = None 

55 

56 if logger is not None: 

57 self.logger = logger.getChild(self.__class__.__name__) 

58 else: 

59 self.logger = logging.getLogger(self.__class__.__name__) 

60 

61 self.ready_event = asyncio.Event() 

62 

63 @property 

64 def address(self) -> Optional[Tuple[str, int]]: 

65 """Get the server's bound address and port.""" 

66 if self.transport: 

67 return self.transport.get_extra_info("socket").getsockname() 

68 return None 

69 

70 async def start(self): 

71 self.logger.info(f"Starting TFTP server on {self.host}:{self.port}") 

72 loop = asyncio.get_running_loop() 

73 

74 self.ready_event.set() 

75 self.transport, self.protocol = await loop.create_datagram_endpoint( 

76 lambda: TftpServerProtocol(self), local_addr=(self.host, self.port) 

77 ) 

78 

79 try: 

80 await self.shutdown_event.wait() 

81 except asyncio.CancelledError: 

82 pass 

83 finally: 

84 self.logger.info("TFTP server shutting down") 

85 await self._cleanup() 

86 

87 async def _cleanup(self): 

88 self.logger.info("Cleaning up TFTP server resources") 

89 

90 # Cancel all active transfers 

91 cleanup_tasks = [transfer.cleanup() for transfer in self.active_transfers.copy()] 

92 if cleanup_tasks: 

93 await asyncio.gather(*cleanup_tasks, return_exceptions=True) 

94 

95 # Close the main transport 

96 if self.transport: 

97 self.transport.close() 

98 self.transport = None 

99 

100 self.logger.info("TFTP server cleanup completed") 

101 

102 async def shutdown(self): 

103 self.logger.info("Shutdown signal received for TFTP server") 

104 self.shutdown_event.set() 

105 

106 def register_transfer(self, transfer: "TftpTransfer"): 

107 self.active_transfers.add(transfer) 

108 self.logger.debug(f"Registered transfer: {transfer}") 

109 

110 def unregister_transfer(self, transfer: "TftpTransfer"): 

111 self.active_transfers.discard(transfer) 

112 self.logger.debug(f"Unregistered transfer: {transfer}") 

113 

114 

115class TftpServerProtocol(asyncio.DatagramProtocol): 

116 """ 

117 Protocol for handling incoming TFTP requests. 

118 """ 

119 

120 def __init__(self, server: TftpServer): 

121 self.server = server 

122 self.transport: Optional[asyncio.DatagramTransport] = None 

123 self.logger = server.logger.getChild(self.__class__.__name__) 

124 

125 def connection_made(self, transport: asyncio.DatagramTransport): 

126 self.transport = transport 

127 self.logger.debug("Server protocol connection established") 

128 

129 def connection_lost(self, exc: Optional[Exception]): 

130 self.logger.info("TFTP server protocol connection lost") 

131 

132 def datagram_received(self, data: bytes, addr: Tuple[str, int]): 

133 self.logger.debug(f"Received datagram from {addr}") 

134 if len(data) < 4: 

135 self.logger.warning(f"Received malformed packet from {addr}") 

136 self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Malformed packet") 

137 return 

138 

139 if self.server.shutdown_event.is_set(): 

140 self.logger.debug(f"Ignoring packet from {addr} - server is shutting down") 

141 return 

142 

143 try: 

144 opcode = Opcode(int.from_bytes(data[0:2], "big")) 

145 except ValueError: 

146 self.logger.error(f"Unknown opcode from {addr}") 

147 self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode") 

148 return 

149 

150 self.logger.debug(f"Received opcode {opcode.name} from {addr}") 

151 

152 if opcode == Opcode.RRQ: 

153 asyncio.create_task(self._handle_read_request(data, addr)) 

154 else: 

155 self.logger.warning(f"Unsupported opcode {opcode} from {addr}") 

156 self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unsupported operation") 

157 

158 async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]): 

159 try: 

160 filename, mode, options = self._parse_request(data) 

161 self.logger.info(f"RRQ from {addr}: '{filename}' in mode '{mode}' with options {options}") 

162 

163 if not self._validate_mode(mode, addr): 

164 return 

165 

166 resolved_path = await self._resolve_and_validate_path(filename, addr) 

167 if not resolved_path: 

168 return 

169 

170 negotiated_options, blksize, timeout = self._negotiate_options(options) 

171 self.logger.info(f"Negotiated options: {negotiated_options}") 

172 await self._start_transfer(resolved_path, addr, blksize, timeout, negotiated_options) 

173 except Exception as e: 

174 self.logger.error(f"Error handling RRQ from {addr}: {e}") 

175 self._send_error(addr, TftpErrorCode.NOT_DEFINED, str(e)) 

176 

177 def _send_oack(self, addr: Tuple[str, int], options: dict): 

178 """Send Option Acknowledgment (OACK) packet.""" 

179 oack_data = Opcode.OACK.to_bytes(2, "big") 

180 for opt_name, opt_value in options.items(): 

181 oack_data += f"{opt_name}\0{str(opt_value)}\0".encode("utf-8") 

182 

183 if self.transport: 

184 self.transport.sendto(oack_data, addr) 

185 self.logger.debug(f"Sent OACK to {addr} with options {options}") 

186 

187 def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str): 

188 error_packet = ( 

189 Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00" 

190 ) 

191 if self.transport: 

192 self.transport.sendto(error_packet, addr) 

193 self.logger.debug(f"Sent ERROR {error_code.name} to {addr}: {message}") 

194 

195 def _parse_request(self, data: bytes) -> Tuple[str, str, dict]: 

196 parts = data[2:].split(b"\x00") 

197 if len(parts) < 2: 

198 raise ValueError("Invalid RRQ format") 

199 

200 filename = parts[0].decode("utf-8") 

201 if len(filename) > 255: # RFC 1350 doesn't specify a limit 

202 raise ValueError("Filename too long") 

203 if not all(c.isprintable() and c not in '<>:"/\\|?*' for c in filename): 

204 raise ValueError("Invalid characters in filename") 

205 if "\x00" in filename: 

206 raise ValueError("Null byte in filename") 

207 mode = parts[1].decode("utf-8").lower() 

208 options = self._parse_options(parts[2:]) 

209 

210 return filename, mode, options 

211 

212 def _parse_options(self, option_parts: list) -> dict: 

213 options = {} 

214 i = 0 

215 while i < len(option_parts) - 1: 

216 try: 

217 opt_name = option_parts[i].decode("utf-8").lower() 

218 opt_value = option_parts[i + 1].decode("utf-8") 

219 options[opt_name] = opt_value 

220 i += 2 

221 except Exception: 

222 break 

223 return options 

224 

225 def _validate_mode(self, mode: str, addr: Tuple[str, int]) -> bool: 

226 if mode not in ("netascii", "octet"): 

227 self.logger.warning(f"Unsupported transfer mode '{mode}' from {addr}") 

228 self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unsupported transfer mode") 

229 return False 

230 return True 

231 

232 async def _resolve_and_validate_path(self, filename: str, addr: Tuple[str, int]) -> Optional[str]: 

233 try: 

234 stat = await self.server.operator.stat(filename) 

235 except FileNotFoundError: 

236 self.logger.error(f"File not found: {filename}") 

237 self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "File not found") 

238 return None 

239 

240 if not stat.mode.is_file(): 

241 self.logger.error(f"Not a file: {filename}") 

242 self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "Not a file") 

243 return None 

244 

245 return filename 

246 

247 def _negotiate_block_size(self, requested_blksize: Optional[str]) -> int: 

248 if requested_blksize is None: 

249 return self.server.block_size 

250 

251 try: 

252 blksize = int(requested_blksize) 

253 if 512 <= blksize <= 65464: 

254 return blksize 

255 else: 

256 self.logger.warning( 

257 f"Requested block size {blksize} out of range (512-65464), using default: {self.server.block_size}" 

258 ) 

259 return self.server.block_size 

260 except ValueError: 

261 self.logger.warning( 

262 f"Invalid block size value '{requested_blksize}', using default: {self.server.block_size}" 

263 ) 

264 return self.server.block_size 

265 

266 def _negotiate_timeout(self, requested_timeout: Optional[str]) -> float: 

267 if requested_timeout is None: 

268 return self.server.timeout 

269 

270 try: 

271 timeout = int(requested_timeout) 

272 if 1 <= timeout <= 255: 

273 return float(timeout) 

274 else: 

275 self.logger.warning( 

276 f"Timeout value {timeout} out of range (1-255), using default: {self.server.timeout}" 

277 ) 

278 return self.server.timeout 

279 except ValueError: 

280 self.logger.warning(f"Invalid timeout value '{requested_timeout}', using default: {self.server.timeout}") 

281 return self.server.timeout 

282 

283 def _negotiate_options(self, options: dict) -> Tuple[dict, int, float]: 

284 negotiated = {} 

285 blksize = self.server.block_size 

286 timeout = self.server.timeout 

287 

288 if "blksize" in options: 

289 requested = options["blksize"] 

290 blksize = self._negotiate_block_size(requested) 

291 negotiated["blksize"] = blksize 

292 

293 if "timeout" in options: 

294 requested = options["timeout"] 

295 timeout = self._negotiate_timeout(requested) 

296 negotiated["timeout"] = int(timeout) 

297 

298 return negotiated, blksize, timeout 

299 

300 async def _start_transfer( 

301 self, filepath: str, addr: Tuple[str, int], blksize: int, timeout: float, negotiated_options: dict 

302 ): 

303 transfer = TftpReadTransfer( 

304 server=self.server, 

305 filepath=filepath, 

306 client_addr=addr, 

307 block_size=blksize, 

308 timeout=timeout, 

309 retries=self.server.retries, 

310 negotiated_options=negotiated_options, 

311 ) 

312 self.server.register_transfer(transfer) 

313 asyncio.create_task(transfer.start()) 

314 

315 

def is_subpath(path: pathlib.Path, root: pathlib.Path) -> bool:
    """Return True if *path* equals *root* or lies beneath it.

    The comparison is purely lexical: nothing is resolved on disk and ``..``
    components are not collapsed, so resolve both paths first if they may
    contain them.
    """
    return path.is_relative_to(root)

322 

323 

class TftpTransfer:
    """
    Base class for TFTP transfers.

    Stores the per-transfer parameters and the transfer's own transport and
    protocol; concrete subclasses implement ``start()``.
    """

    def __init__(
        self,
        server: TftpServer,
        filepath: pathlib.Path,
        client_addr: Tuple[str, int],
        block_size: int,
        timeout: float,
        retries: int,
    ):
        self.server = server
        self.filepath = filepath
        self.client_addr = client_addr
        self.block_size = block_size
        self.timeout = timeout
        self.retries = retries
        # Populated by the subclass once it opens its dedicated UDP endpoint.
        self.transport: Optional[asyncio.DatagramTransport] = None
        self.protocol: Optional["TftpTransferProtocol"] = None
        self.cleanup_task: Optional[asyncio.Task] = None
        self.logger = server.logger.getChild(type(self).__name__)

    async def start(self):
        """Run the transfer; must be overridden by subclasses."""
        raise NotImplementedError

    async def cleanup(self):
        """Close this transfer's transport (at most once) and unregister it from the server."""
        self.logger.info(f"Cleaning up transfer for {self.client_addr}")
        transport = self.transport
        if transport:
            transport.close()
            self.transport = None
        self.server.unregister_transfer(self)

360 

361 

class TftpReadTransfer(TftpTransfer):
    """
    Serve one read request: send the file to a single client as DATA blocks.

    Each block waits for the client's ACK before the next one is sent.
    Negotiated options are first confirmed with an OACK when they differ
    from the protocol defaults.
    """

    def __init__(
        self,
        server: TftpServer,
        filepath: str,
        client_addr: Tuple[str, int],
        block_size: int,
        timeout: float,
        retries: int,
        negotiated_options: Optional[dict] = None,
    ):
        super().__init__(
            server=server,
            filepath=filepath,
            client_addr=client_addr,
            block_size=block_size,
            timeout=timeout,
            retries=retries,
        )
        # Number of the block currently in flight (0 while an OACK is pending).
        self.block_num = 0
        self.ack_received = asyncio.Event()
        self.last_ack = 0
        self.oack_confirmed = False
        self.negotiated_options = negotiated_options
        # Last packet handed to _send_with_retries.
        self.current_packet: Optional[bytes] = None

    async def start(self):
        """Run the transfer to completion.

        Cleanup always runs, including when initialization fails or raises.
        Previously, those paths leaked the transport and left the transfer
        registered.
        """
        self.logger.info(f"Starting read transfer of '{self.filepath}' to {self.client_addr}")

        try:
            if not await self._initialize_transfer():
                return

            # if no options were negotiated, we can start sending data immediately
            if not self.negotiated_options:
                self.oack_confirmed = True

            await self._perform_transfer()
        except Exception as e:
            self.logger.error(f"Error during read transfer: {e}")
        finally:
            await self.cleanup()

    def _needs_oack(self) -> bool:
        """Return True if the negotiated options differ from the defaults and must be OACKed."""
        if not self.negotiated_options:
            return False
        # The client may request blksize or timeout alone, so either key can be missing.
        blksize = self.negotiated_options.get("blksize", 512)
        timeout = self.negotiated_options.get("timeout", self.server.timeout)
        return blksize != 512 or timeout != self.server.timeout

    async def _initialize_transfer(self) -> bool:
        """Open this transfer's UDP endpoint and, if needed, get the OACK acknowledged.

        Returns False if the OACK was never acknowledged.
        """
        loop = asyncio.get_running_loop()

        self.transport, self.protocol = await loop.create_datagram_endpoint(
            lambda: TftpTransferProtocol(self), local_addr=("0.0.0.0", 0), remote_addr=self.client_addr
        )
        local_addr = self.transport.get_extra_info("sockname")
        self.logger.debug(f"Transfer bound to local {local_addr}")

        # Only send OACK if we have non-default options to negotiate
        if self._needs_oack():
            oack_packet = self._create_oack_packet()
            if not await self._send_with_retries(oack_packet, is_oack=True):
                self.logger.error("Failed to get acknowledgment for OACK")
                return False
        else:
            # No OACK is sent, so there is nothing for the client to confirm.
            self.oack_confirmed = True

        self.block_num = 1
        return True

    async def _perform_transfer(self):
        """Read the file block by block and send each one until EOF, failure or shutdown."""
        async with await self.server.operator.open(self.filepath, "rb") as f:
            while True:
                if self.server.shutdown_event.is_set():
                    self.logger.info(f"Server shutdown detected, stopping transfer to {self.client_addr}")
                    break

                # read a full block or until EOF
                data = bytearray()
                while len(data) < self.block_size:
                    chunk = await f.read(size=self.block_size - len(data))
                    if not chunk:  # EOF reached
                        break
                    data.extend(chunk)

                # send the data (converted to bytes)
                if not await self._handle_data_block(bytes(data)):
                    break

    async def _handle_data_block(self, data: bytes) -> bool:
        """
        Handle sending a block of data to the client.
        Returns False if transfer should stop, True if it should continue.

        A short (or empty) block terminates the transfer. That is why a file
        whose size is a multiple of the block size ends with an empty DATA block.
        """
        if not data and self.block_num == 1:
            # Empty file case
            packet = self._create_data_packet(b"")
            await self._send_with_retries(packet)
            return False
        elif data:
            packet = self._create_data_packet(data)
            success = await self._send_with_retries(packet)
            if not success:
                self.logger.error(f"Failed to send block {self.block_num} to {self.client_addr}")
                return False

            self.logger.debug(f"Block {self.block_num} sent successfully")
            self.block_num += 1

            # wrap block number around if it exceeds 16 bits
            self.block_num %= 65536

            if len(data) < self.block_size:
                self.logger.info(f"Final block {self.block_num - 1} sent")
                return False
            return True
        else:
            # EOF reached
            packet = self._create_data_packet(b"")
            success = await self._send_with_retries(packet)
            if not success:
                self.logger.error(f"Failed to send final block {self.block_num}")
            else:
                self.logger.info(f"Transfer complete, final block {self.block_num}")
            return False

    def _create_oack_packet(self) -> bytes:
        """Build an OACK packet listing each negotiated ``name\\0value\\0`` pair."""
        packet = Opcode.OACK.to_bytes(2, "big")
        for opt_name, opt_value in self.negotiated_options.items():
            packet += f"{opt_name}\0{str(opt_value)}\0".encode("utf-8")
        return packet

    def _create_data_packet(self, data: bytes) -> bytes:
        """Build a DATA packet for the current block number."""
        return Opcode.DATA.to_bytes(2, "big") + self.block_num.to_bytes(2, "big") + data

    def _send_packet(self, packet: bytes):
        """Send *packet* to the connected client and log what was sent."""
        self.transport.sendto(packet)
        if packet[0:2] == Opcode.DATA.to_bytes(2, "big"):
            block = int.from_bytes(packet[2:4], "big")
            data_length = len(packet) - 4
            self.logger.debug(f"Sent DATA block {block} ({data_length} bytes) to {self.client_addr}")
        elif packet[0:2] == Opcode.OACK.to_bytes(2, "big"):
            self.logger.debug(f"Sent OACK to {self.client_addr}")

    async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool:
        """Send *packet* and wait for its ACK, resending on timeout or a wrong ACK.

        At most ``self.retries`` sends are made. Returns True once the
        expected ACK (block 0 for an OACK) arrives, otherwise False.
        """
        self.current_packet = packet
        expected_block = 0 if is_oack else self.block_num

        for attempt in range(1, self.retries + 1):
            try:
                self._send_packet(packet)
                self.logger.debug(
                    f"Sent {'OACK' if is_oack else 'DATA'} block {expected_block}, waiting for ACK (Attempt {attempt})"
                )
                self.ack_received.clear()
                await asyncio.wait_for(self.ack_received.wait(), timeout=self.timeout)

                if self.last_ack == expected_block:
                    self.logger.debug(f"ACK received for block {expected_block}")
                    return True
                else:
                    self.logger.warning(f"Received wrong ACK: expected {expected_block}, got {self.last_ack}")

            except asyncio.TimeoutError:
                self.logger.warning(f"Timeout waiting for ACK of block {expected_block} (Attempt {attempt})")

        return False

    def handle_ack(self, block_num: int):
        """Record an ACK from the client and wake the sender if it is the one awaited."""
        self.logger.debug(f"Received ACK for block {block_num} from {self.client_addr}")

        # special handling for OACK acknowledgment
        if not self.oack_confirmed and self.negotiated_options and block_num == 0:
            self.oack_confirmed = True
            self.last_ack = block_num
            self.ack_received.set()
            return

        if block_num == self.block_num:
            self.last_ack = block_num
            self.ack_received.set()
        elif block_num == (self.block_num - 1) % 65536:
            # Do NOT retransmit on a duplicate ACK (RFC 1123 section 4.2.3.1,
            # "Sorcerer's Apprentice"): doing so doubles every later packet.
            # Real losses are recovered by the timeout in _send_with_retries.
            self.logger.warning(f"Duplicate ACK for block {block_num} received, ignoring")
        else:
            self.logger.warning(f"Out of sequence ACK: expected {self.block_num}, got {block_num}")

542 

543 

class TftpTransferProtocol(asyncio.DatagramProtocol):
    """
    Protocol for handling ACKs during a TFTP transfer.

    Attached to one transfer's dedicated socket: ACKs are forwarded to the
    owning transfer, client ERROR packets are logged, and any other opcode
    is answered with an ILLEGAL_OPERATION error.
    """

    def __init__(self, transfer: TftpReadTransfer):
        self.transfer = transfer
        self.logger = transfer.logger.getChild(self.__class__.__name__)

    def connection_made(self, transport: asyncio.DatagramTransport):
        # Hand the transport to the transfer, which sends through it.
        self.transfer.transport = transport
        local_addr = transport.get_extra_info("sockname")
        self.logger.debug(f"Transfer protocol connection established on {local_addr} for {self.transfer.client_addr}")

    def datagram_received(self, data: bytes, addr: Tuple[str, int]):
        """Validate a packet from the client and route it by opcode."""
        self.logger.debug(f"Received datagram from {addr}")
        if addr != self.transfer.client_addr:
            self.logger.warning(f"Ignoring packet from unknown source {addr}")
            return

        if len(data) < 4:
            self.logger.warning(f"Received malformed packet from {addr}")
            return

        try:
            opcode = Opcode(int.from_bytes(data[0:2], "big"))
        except ValueError:
            self.logger.error(f"Unknown opcode from {addr}")
            self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode")
            return

        if opcode == Opcode.ACK:
            block_num = int.from_bytes(data[2:4], "big")
            self.logger.debug(f"Received ACK for block {block_num} from {addr}")
            self.transfer.handle_ack(block_num)
        elif opcode == Opcode.ERROR:
            # RFC 1350: ERROR packets are never acknowledged. The client has
            # given up, so replying with another ERROR is pointless and could
            # ping-pong. The transfer itself is not aborted here.
            error_code = int.from_bytes(data[2:4], "big")
            message = data[4:].split(b"\x00", 1)[0].decode("utf-8", errors="replace")
            self.logger.warning(f"Client {addr} reported error {error_code}: {message}")
        else:
            self.logger.warning(f"Unexpected opcode {opcode} from {addr}")
            self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unexpected opcode")

    def error_received(self, exc):
        self.logger.error(f"Error received: {exc}")

    def connection_lost(self, exc):
        self.logger.debug(f"Connection closed for transfer to {self.transfer.client_addr}")

    def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str):
        """Send an ERROR packet over the transfer's connected transport (addr is used for logging)."""
        error_packet = (
            Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00"
        )
        if self.transfer.transport:
            self.transfer.transport.sendto(error_packet)
            self.logger.debug(f"Sent ERROR {error_code.name} to {addr}: {message}")