"""
A module implementing a base tftp (RFC 1350) connection interface.
"""
# built-in
import asyncio
from contextlib import AsyncExitStack
from io import BytesIO
import logging
from pathlib import Path
from typing import BinaryIO, Callable
# third-party
from vcorelib.io import BinaryMessage
from vcorelib.math import metrics_time_ns
# internal
from runtimepy.net import IpHost
from runtimepy.net.udp.connection import UdpConnection
from runtimepy.net.udp.tftp.endpoint import TftpEndpoint
from runtimepy.net.udp.tftp.enums import (
DEFAULT_MODE,
TftpErrorCode,
TftpOpCode,
encode_filename_mode,
parse_filename_mode,
)
from runtimepy.net.util import IpHostTuplelike, normalize_host
from runtimepy.primitives import Double, Uint16
# Initial value of the commandable 'reemit_period_s' channel: the time
# (seconds) between packet re-transmissions.
REEMIT_PERIOD_S = 0.20
# Initial value of the commandable 'timeout_s' channel: a timeout (seconds)
# used for each step.
DEFAULT_TIMEOUT_S = 1.0
# Signature of the extra error handlers invoked when an error message is
# received: (error code, error message, (host, port) of the sender).
TftpErrorHandler = Callable[[TftpErrorCode, str, tuple[str, int]], None]
[docs]
class BaseTftpConnection(UdpConnection):
"""
A class implementing a basic tftp interface.

Messages are encoded and sent with the 'send_*' methods, received
datagrams are dispatched by opcode in 'process_datagram', and per-address
state is kept in endpoint instances (see 'endpoint').
"""
# Alias used for logging.
# NOTE(review): consumed by the base class - confirm there.
log_alias = "TFTP"
# Root directory for this connection; assigned in 'init' and 'set_root' and
# handed to every endpoint that gets created.
_path: Path
# NOTE(review): the two options below are presumably read by the base
# connection class - confirm their exact semantics against UdpConnection.
default_auto_restart = True
should_connect = False
[docs]
def set_root(self, path: Path) -> None:
    """
    Change the root directory used by this connection.

    The new root is stored on the instance and forwarded to every endpoint
    that has already been created, then the change is logged.
    """

    self._path = path
    root = self._path

    # Endpoints were constructed with the previous root, so update each one.
    for known in self._endpoints.values():
        known.set_root(root)

    self.logger.info("Set root directory to '%s'.", root)
@property
def path(self) -> Path:
    """Return the root path currently assigned to this connection."""

    root: Path = self._path
    return root
[docs]
def init(self) -> None:
    """
    Initialize this instance.

    Sets the default root path, registers the tftp enumerations and
    channels with the environment, builds the opcode dispatch table,
    creates the sender callbacks that are handed to endpoints and creates
    the (initially empty) endpoint bookkeeping.
    """

    super().init()

    # Path to serve files from.
    self._path = Path()

    for enum_type in (TftpOpCode, TftpErrorCode):
        enum_type.register_enum(self.env.enums)

    # Primitives used to encode and decode the fixed-size message fields.
    self.opcode = Uint16(time_source=metrics_time_ns)
    self.env.channel("opcode", self.opcode, enum="TftpOpCode")

    self.block_number = Uint16(time_source=metrics_time_ns)
    self.env.channel("block_number", self.block_number)

    self.error_code = Uint16(time_source=metrics_time_ns)
    self.env.channel("error_code", self.error_code, enum="TftpErrorCode")

    # Extra callbacks invoked whenever an error message is received.
    self.error_handlers: list[TftpErrorHandler] = []

    # Commandable timing parameters shared with every endpoint.
    self.endpoint_period = Double(value=REEMIT_PERIOD_S)
    self.env.channel(
        "reemit_period_s",
        self.endpoint_period,
        commandable=True,
        description="Time between packet re-transmissions.",
    )

    self.endpoint_timeout = Double(value=DEFAULT_TIMEOUT_S)
    self.env.channel(
        "timeout_s",
        self.endpoint_timeout,
        commandable=True,
        description="A timeout used for each step.",
    )

    # Message parsers, keyed by integer opcode.
    self.handlers = {
        opcode.value: handler
        for opcode, handler in (
            (TftpOpCode.RRQ, self._handle_rrq),
            (TftpOpCode.WRQ, self._handle_wrq),
            (TftpOpCode.DATA, self._handle_data),
            (TftpOpCode.ACK, self._handle_ack),
            (TftpOpCode.ERROR, self._handle_error),
        )
    }

    # Callbacks given to endpoints so they can transmit via this
    # connection. Parameter names are part of the callback signature.
    def _emit_data(block: int, data: bytes, addr: IpHostTuplelike) -> None:
        """Send data via this connection instance."""
        self.send_data(block, data, addr=addr)

    self.data_sender = _emit_data

    def _emit_ack(block: int, addr: IpHostTuplelike) -> None:
        """Send an ack via this connection."""
        self.send_ack(block=block, addr=addr)

    self.ack_sender = _emit_ack

    def _emit_error(
        error_code: TftpErrorCode,
        message: str,
        addr: IpHostTuplelike,
    ) -> None:
        """Send an error via this connection."""
        self.send_error(error_code, message, addr=addr)

    self.error_sender = _emit_error

    # Endpoints by normalized address, plus endpoints (by hostname) that
    # are waiting for the first reply to a request.
    self._endpoints: dict[IpHost, TftpEndpoint] = {}
    self._awaiting_first_ack: dict[str, TftpEndpoint] = {}
    self._awaiting_first_block: dict[str, TftpEndpoint] = {}
# self._self = self.endpoint(self.local_address)
async def _await_first_ack(
    self,
    stack: AsyncExitStack,
    addr: IpHostTuplelike = None,
) -> tuple[TftpEndpoint, asyncio.Event]:
    """
    Set up an endpoint to wait for an initial ack from a server.

    The endpoint's lock is entered on 'stack' (so it is held until the
    stack is closed), an event is registered for block zero and the
    endpoint is remembered by hostname so the first ack can be matched.
    Returns the endpoint and the event.
    """

    target = self.endpoint(addr)
    await stack.enter_async_context(target.lock)

    first_ack = asyncio.Event()
    target.awaiting_acks[0] = first_ack
    self._awaiting_first_ack[target.addr.hostname] = target

    return target, first_ack
async def _await_first_block(
    self,
    stack: AsyncExitStack,
    addr: IpHostTuplelike = None,
) -> tuple[TftpEndpoint, asyncio.Event]:
    """
    Set up an endpoint to wait for an initial block from a server.

    The endpoint's lock is entered on 'stack' (so it is held until the
    stack is closed), an event is registered for block one and the
    endpoint is remembered by hostname so the first block can be matched.
    Returns the endpoint and the event.
    """

    target = self.endpoint(addr)
    await stack.enter_async_context(target.lock)

    first_block = asyncio.Event()
    target.awaiting_blocks[1] = first_block
    self._awaiting_first_block[target.addr.hostname] = target

    return target, first_block
[docs]
def endpoint(self, addr: IpHostTuplelike = None) -> TftpEndpoint:
    """
    Lookup an endpoint instance from an address.

    When no address is given this connection's remote address is used. An
    endpoint is created (and remembered) the first time a normalized
    address is seen.
    """

    target = self.remote_address if addr is None else addr
    assert target is not None

    key = normalize_host(*target)

    found = self._endpoints.get(key)
    if found is None:
        # First contact with this address: build its endpoint around this
        # connection's root path, logger, senders and timing parameters.
        found = TftpEndpoint(
            self._path,
            self.logger,
            key,
            self.data_sender,
            self.ack_sender,
            self.error_sender,
            self.endpoint_period,
            self.endpoint_timeout,
        )
        self._endpoints[key] = found

    return found
[docs]
def send_rrq(
    self,
    filename: str,
    mode: str = DEFAULT_MODE,
    addr: IpHostTuplelike = None,
) -> None:
    """
    Send a read request.

    The filename and mode are encoded into the request body and sent with
    the read-request opcode to 'addr'.
    """

    request = encode_filename_mode(filename, mode)
    self._send_message(TftpOpCode.RRQ, request, addr=addr)
[docs]
def send_wrq(
    self,
    filename: str,
    mode: str = DEFAULT_MODE,
    addr: IpHostTuplelike = None,
) -> None:
    """
    Send a write request.

    The filename and mode are encoded into the request body and sent with
    the write-request opcode to 'addr'.
    """

    request = encode_filename_mode(filename, mode)
    self._send_message(TftpOpCode.WRQ, request, addr=addr)
[docs]
def send_data(
    self,
    block: int,
    data: bytes,
    addr: IpHostTuplelike = None,
) -> None:
    """
    Send a data message.

    The block number is stored in the 'block_number' primitive, whose
    encoded bytes are placed in front of 'data'.
    """

    self.block_number.value = block
    payload = bytes(self.block_number) + data
    self._send_message(TftpOpCode.DATA, payload, addr=addr)
[docs]
def send_ack(self, block: int = 0, addr: IpHostTuplelike = None) -> None:
    """
    Send an acknowledge message.

    The block number is stored in the 'block_number' primitive and its
    encoded bytes form the whole message body.
    """

    self.block_number.value = block
    payload = bytes(self.block_number)
    self._send_message(TftpOpCode.ACK, payload, addr=addr)
[docs]
def send_error(
    self,
    error_code: TftpErrorCode,
    message: str,
    addr: IpHostTuplelike = None,
) -> None:
    """
    Send an error message and log that it was sent.

    The body is the encoded error code followed by the encoded message and
    a terminating zero byte. The log entry is rate-governed by the
    destination endpoint's log limiter.
    """

    with BytesIO() as buffer:
        self.error_code.value = error_code.value
        self.error_code.to_stream(buffer)
        buffer.write(message.encode() + b"\x00")

        # Must happen before the buffer is closed by the 'with' block.
        self._send_message(TftpOpCode.ERROR, buffer.getvalue(), addr=addr)

    # Log errors sent.
    recipient = self.endpoint(addr)
    self.governed_log(
        recipient.log_limiter,
        "Sent error '%s: %s' to %s.",
        error_code.name,
        message,
        recipient,
        level=logging.WARNING,
    )
def _send_message(
    self,
    opcode: TftpOpCode,
    data: bytes,
    addr: IpHostTuplelike = None,
) -> None:
    """
    Send a tftp message.

    The opcode is stored in the 'opcode' primitive and encoded first, then
    'data' is appended and the result is sent to 'addr'.
    """

    with BytesIO() as buffer:
        # Opcode header.
        self.opcode.value = opcode.value
        self.opcode.to_stream(buffer)

        # Message body.
        buffer.write(data)

        datagram = buffer.getvalue()
        self.sendto(datagram, addr=addr)
async def _handle_rrq(
    self, stream: BinaryIO, addr: tuple[str, int]
) -> None:
    """
    Handle a read request.

    The filename and mode are parsed from 'stream' and given to the
    sender's endpoint; any task it returns is added to the connection's
    tasks.
    """

    requester = self.endpoint(addr)
    request = parse_filename_mode(stream)

    task = requester.handle_read_request(*request)
    if task is None:
        return

    self._conn_tasks.append(task)
async def _handle_wrq(
    self, stream: BinaryIO, addr: tuple[str, int]
) -> None:
    """
    Handle a write request.

    The filename and mode are parsed from 'stream' and given to the
    sender's endpoint; any task it returns is added to the connection's
    tasks.
    """

    requester = self.endpoint(addr)
    request = parse_filename_mode(stream)

    task = requester.handle_write_request(*request)
    if task is None:
        return

    self._conn_tasks.append(task)
async def _handle_data(
    self, stream: BinaryIO, addr: tuple[str, int]
) -> None:
    """
    Handle a data message.

    Reads the block number and passes the remaining bytes to the sender's
    endpoint. When block one arrives from a hostname with an endpoint
    waiting for its first block, that waiting endpoint takes over the
    sender's address and handles the data instead.
    """

    sender = self.endpoint(addr)
    block = self._read_block_number(stream)

    # Check if we're currently waiting for an initial block.
    pending = self._awaiting_first_block
    host = sender.addr.hostname
    if block == 1 and host in pending:
        waiting = pending.pop(host)
        self._endpoints[sender.addr] = waiting
        sender = waiting.update_from_other(sender)

    sender.handle_data(block, stream.read())
async def _handle_ack(
    self, stream: BinaryIO, addr: tuple[str, int]
) -> None:
    """
    Handle an acknowledge message.

    Reads the block number and passes it to the sender's endpoint. When
    block zero arrives from a hostname with an endpoint waiting for its
    first ack, that waiting endpoint takes over the sender's address and
    handles the ack instead.
    """

    sender = self.endpoint(addr)
    block = self._read_block_number(stream)

    # The initial acknowledgement comes from the same host but a different
    # port, so update references when this is detected.
    pending = self._awaiting_first_ack
    host = sender.addr.hostname
    if block == 0 and host in pending:
        waiting = pending.pop(host)
        self._endpoints[sender.addr] = waiting
        sender = waiting.update_from_other(sender)

    sender.handle_ack(block)
def _read_block_number(self, stream: BinaryIO) -> int:
    """
    Read block number from the stream.

    Decoding goes through the 'block_number' primitive's 'from_stream'.
    """

    number = self.block_number.from_stream(stream)
    return number
async def _handle_error(
    self, stream: BinaryIO, addr: tuple[str, int]
) -> None:
    """
    Handle an error message.

    Decodes the error code and message from 'stream', calls every
    registered error handler with (error code, message, addr) and then
    passes the error to the sender's endpoint.

    The message is sent with a terminating zero byte (see 'send_error'),
    so the terminator and anything after it is dropped before decoding.
    Bytes that are not valid text are replaced rather than raising, since
    the data comes straight from the network.
    """

    # Update underlying primitive.
    error_code = TftpErrorCode(self.error_code.from_stream(stream))

    # Strip the zero terminator so handlers get the text only.
    raw_message, _, _ = stream.read().partition(b"\x00")
    message = raw_message.decode(errors="replace")

    # Call extra error handlers.
    for handler in self.error_handlers:
        handler(error_code, message, addr)

    self.endpoint(addr).handle_error(error_code, message)
[docs]
async def process_datagram(
    self, data: BinaryMessage, addr: tuple[str, int]
) -> bool:
    """
    Process a datagram.

    The opcode is decoded from the start of 'data' and the matching
    handler is awaited with the rest of the stream and the sender's
    address. For an unknown opcode an illegal-operation error is sent
    back to the sender and the event is logged.

    Always returns True.
    """

    with BytesIO(data) as stream:
        self.opcode.from_stream(stream)
        code: int = self.opcode.value

        handler = self.handlers.get(code)
        if handler is not None:
            await handler(stream, addr)
        else:
            msg = f"Unknown opcode {code}"

            # Reply to whoever sent the datagram. Without an explicit
            # address the error would go to the default remote address,
            # which may be a different peer or not set at all.
            self.send_error(
                TftpErrorCode.ILLEGAL_OPERATION, msg, addr=addr
            )
            self.logger.error("%s from %s:%d.", msg, addr[0], addr[1])

    return True