Source code for runtimepy.net.udp.tftp

"""
A module implementing a tftp (RFC 1350) interface.
"""

# built-in
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from os import fsync
from pathlib import Path
from typing import Any, AsyncIterator

# third-party
from vcorelib.asyncio.poll import repeat_until
from vcorelib.paths.context import PossiblePath, as_path, tempfile
from vcorelib.paths.hashing import file_md5_hex
from vcorelib.paths.info import FileInfo

# internal
from runtimepy.net import normalize_host
from runtimepy.net.udp.tftp.base import (
    DEFAULT_TIMEOUT_S,
    REEMIT_PERIOD_S,
    BaseTftpConnection,
    TftpErrorHandler,
)
from runtimepy.net.udp.tftp.enums import DEFAULT_MODE
from runtimepy.net.util import IpHostTuplelike


[docs] class TftpConnection(BaseTftpConnection): """A class implementing a basic tftp interface."""
[docs] async def request_read( self, destination: Path, filename: str, mode: str = DEFAULT_MODE, addr: IpHostTuplelike = None, ) -> bool: """Request a tftp read operation.""" end_of_data = False idx = 1 async with AsyncExitStack() as stack: # Claim read lock and ignore cancellation. stack.enter_context(suppress(asyncio.CancelledError)) endpoint, event = await self._await_first_block(stack, addr=addr) def ack_sender() -> None: """Send acks.""" endpoint.ack_sender(idx - 1, endpoint.addr) def send_rrq() -> None: """Send request""" self.send_rrq(filename, mode=mode, addr=addr) self.logger.info( "Requesting '%s' (%s) -> %s.", filename, mode, destination ) with self.log_time("Awaiting first data block", reminder=True): # Wait for first data block. if not await repeat_until( send_rrq, event, endpoint.period.value, endpoint.timeout.value, ): endpoint.awaiting_blocks.pop(idx, None) self.logger.error("Didn't receive any data block.") return False path_fd = stack.enter_context(destination.open("wb")) end_of_data = False def write_block() -> None: """Write block.""" # Write block. nonlocal idx data = endpoint.blocks[idx] path_fd.write(data) del endpoint.blocks[idx] # Compute if this is the end of the stream. nonlocal end_of_data end_of_data = len(data) < endpoint.max_block_size if not end_of_data: idx += 1 else: fsync(path_fd.fileno()) write_block() success = True while not end_of_data and success: event = asyncio.Event() endpoint.awaiting_blocks[idx] = event success = await repeat_until( ack_sender, event, endpoint.period.value, endpoint.timeout.value, ) if success: write_block() # Repeat last ack in the background. if end_of_data: self._conn_tasks.append( asyncio.create_task( repeat_until( # type: ignore ack_sender, asyncio.Event(), endpoint.period.value, endpoint.timeout.value, ) ) ) # Ensure at least one ack sends. await asyncio.sleep(0.01) self.logger.info( "Read %s (%s).", FileInfo.from_file(destination), "end of data" if end_of_data else "not end of data", ) return end_of_data
[docs] async def request_write( self, source: PossiblePath, filename: str, mode: str = DEFAULT_MODE, addr: IpHostTuplelike = None, verify: bool = True, ) -> bool: """Request a tftp write operation.""" result = False with as_path(source) as src: async with AsyncExitStack() as stack: # Claim write lock and ignore cancellation. stack.enter_context(suppress(asyncio.CancelledError)) # Set up first-ack handling. endpoint, event = await self._await_first_ack(stack, addr=addr) def send_wrq() -> None: """Send request.""" self.send_wrq(filename, mode=mode, addr=addr) # Wait for zeroeth ack. with self.log_time("Awaiting first ack", reminder=True): if not await repeat_until( send_wrq, event, endpoint.period.value, endpoint.timeout.value, ): endpoint.awaiting_acks.pop(0, None) return result result = await endpoint.serve_file(src) # Verify by reading back. if verify and result: with self.log_time("Verifying write via read", reminder=True): with tempfile() as tmp: result = await self.request_read( tmp, filename, mode=mode, addr=addr ) # Compare hashes. self.logger.info( "Reading '%s' %s.", filename, "succeeded" if result else "failed", ) if result: result = file_md5_hex(src) == file_md5_hex(tmp) self.logger.info( "MD5 sums %s", "matched." if result else "didn't match!", ) return result
TFTP_PORT = 69
[docs] @asynccontextmanager async def tftp( addr: IpHostTuplelike, process_kwargs: dict[str, Any] = None, connection_kwargs: dict[str, Any] = None, timeout_s: float = DEFAULT_TIMEOUT_S, reemit_period_s: float = REEMIT_PERIOD_S, error_handler: TftpErrorHandler = None, ) -> AsyncIterator[TftpConnection]: """Use a tftp connection as a managed context.""" if process_kwargs is None: process_kwargs = {} if connection_kwargs is None: connection_kwargs = {} addr = normalize_host(*addr) # Create and start connection. conn = await TftpConnection.create_connection( remote_addr=(addr.name, addr.port), **connection_kwargs ) # Add error handlers. if error_handler is not None: conn.error_handlers.append(error_handler) async with conn.process_then_disable(**process_kwargs): # Set parameters. conn.endpoint_timeout.value = timeout_s conn.endpoint_period.value = reemit_period_s yield conn
[docs] async def tftp_write( addr: IpHostTuplelike, source: PossiblePath, filename: str, mode: str = DEFAULT_MODE, verify: bool = True, process_kwargs: dict[str, Any] = None, connection_kwargs: dict[str, Any] = None, timeout_s: float = DEFAULT_TIMEOUT_S, reemit_period_s: float = REEMIT_PERIOD_S, error_handler: TftpErrorHandler = None, ) -> bool: """Attempt to perform a tftp write.""" async with tftp( addr, process_kwargs=process_kwargs, connection_kwargs=connection_kwargs, timeout_s=timeout_s, reemit_period_s=reemit_period_s, error_handler=error_handler, ) as conn: # Perform tftp interaction. result = await conn.request_write( source, filename, mode=mode, addr=addr, verify=verify ) return result
[docs] async def tftp_read( addr: IpHostTuplelike, destination: Path, filename: str, mode: str = DEFAULT_MODE, process_kwargs: dict[str, Any] = None, connection_kwargs: dict[str, Any] = None, timeout_s: float = DEFAULT_TIMEOUT_S, reemit_period_s: float = REEMIT_PERIOD_S, error_handler: TftpErrorHandler = None, ) -> bool: """Attempt to perform a tftp read.""" async with tftp( addr, process_kwargs=process_kwargs, connection_kwargs=connection_kwargs, timeout_s=timeout_s, reemit_period_s=reemit_period_s, error_handler=error_handler, ) as conn: result = await conn.request_read( destination, filename, mode=mode, addr=addr ) return result