Source code for runtimepy.net.mixin
"""
Various networking-related class utilities.
"""
from __future__ import annotations
# built-in
import asyncio as _asyncio
from socket import SocketType as _SocketType
from typing import Callable
from typing import Optional as _Optional
from typing import cast as _cast
# third-party
from vcorelib.io import BinaryMessage
# internal
from runtimepy.net import IpHost as _IpHost
from runtimepy.net import normalize_host as _normalize_host
from runtimepy.net.mtu import ETHERNET_MTU, UDP_DEFAULT_MTU, host_discover_mtu
[docs]
class BinaryMessageQueueMixin:
"""A mixin for adding a 'queue' attribute."""
def __init__(self) -> None:
"""Initialize this protocol."""
self.queue: _asyncio.Queue[BinaryMessage] = _asyncio.Queue()
[docs]
class TransportMixin:
"""A class simplifying evaluation of local and remote addresses."""
_transport: _asyncio.BaseTransport
remote_address: _Optional[_IpHost]
[docs]
def set_transport(self, transport: _asyncio.BaseTransport) -> None:
"""Set the transport for this instance."""
self._transport = transport
# Get the local address of this connection.
self.local_address = _normalize_host(
*self._transport.get_extra_info("sockname")
)
# A bug in the Windows implementation causes the 'addr' argument of
# sendto to be required. Save a copy of the remote address (may be
# None).
self.remote_address = self._remote_address()
def __init__(self, transport: _asyncio.BaseTransport) -> None:
"""Initialize this instance."""
self.set_transport(transport)
[docs]
def mtu(
self,
probe_size: int = UDP_DEFAULT_MTU,
fallback: int = ETHERNET_MTU,
probe_create: Callable[[int], bytes] = bytes,
) -> int:
"""
Get a maximum transmission unit for this connection. Underlying sockets
must be connected for this to work.
"""
assert (
self.remote_address is not None
), "Not connected! Can't probe MTU."
return host_discover_mtu(
self.local_address,
self.remote_address,
probe_size,
fallback,
kind=self.socket.type,
probe_create=probe_create,
)
@property
def socket(self) -> _SocketType:
"""Get this instance's underlying socket."""
return _cast(_SocketType, self._transport.get_extra_info("socket"))
def _remote_address(self) -> _Optional[_IpHost]:
"""Get a possible remote address for this connection."""
result = self._transport.get_extra_info("peername")
addr = None
if result is not None:
addr = _normalize_host(*result)
return addr
[docs]
def logger_name(self, prefix: str = "") -> str:
"""Get a logger name for this connection."""
name = prefix + str(self.local_address)
if self.remote_address is not None:
name += f" -> {self.remote_address}"
return name