"""
A module implementing various networking utilities.
"""
# built-in
from contextlib import suppress as _suppress
from functools import cache
import ipaddress
import socket as _socket
from typing import Awaitable, NamedTuple, Optional, TypeVar
from typing import Union as _Union
# third-party
from vcorelib.logging import LoggerType
[docs]
class IPv4Host(NamedTuple):
    """
    An IPv4 host stored as a (name, port) pair.

    See: https://docs.python.org/3/library/socket.html#socket-families.
    """

    name: str = ""
    port: int = 0

    def __str__(self) -> str:
        """Render this host as a 'hostname:port' string."""
        return hostname_port(self.name, self.port)

    def __hash__(self) -> int:
        """Hash this host based on its string representation."""
        return hash(str(self))

    @property
    def family(self) -> int:
        """The socket address-family constant for IPv4."""
        return _socket.AF_INET

    @property
    def hostname(self) -> str:
        """The hostname looked up from this instance's name."""
        return hostname(self.name)

    @property
    def address_str(self) -> str:
        """The address string for this host."""
        # An empty name resolves through the IPv4 wildcard address.
        return address_str(
            self.name, family=self.family, fallback_host="0.0.0.0"
        )

    @property
    def address(self) -> ipaddress.IPv4Address:
        """An address object built from this host's address string."""
        return ipaddress.IPv4Address(self.address_str)

    @property
    def address_str_tuple(self) -> tuple[str, int]:
        """The (address string, port) tuple for this instance."""
        return (self.address_str, self.port)

    def zero_port(self) -> "IPv4Host":
        """Get an equivalent host whose port is zero."""
        # Nothing to do when the port is already zero: hand back this object.
        if self.port == 0:
            return self
        return IPv4Host(self.name, 0)
[docs]
class IPv6Host(NamedTuple):
    """
    An IPv6 host stored as a (name, port, flowinfo, scope_id) tuple.

    See: https://docs.python.org/3/library/socket.html#socket-families.
    """

    name: str = ""
    port: int = 0
    flowinfo: int = 0
    scope_id: int = 0

    def __str__(self) -> str:
        """Render this host as a 'hostname:port' string."""
        return hostname_port(self.name, self.port)

    def __hash__(self) -> int:
        """Hash this host based on its string representation."""
        return hash(str(self))

    @property
    def family(self) -> int:
        """The socket address-family constant for IPv6."""
        return _socket.AF_INET6

    @property
    def hostname(self) -> str:
        """The hostname looked up from this instance's name."""
        return hostname(self.name)

    @property
    def address_str(self) -> str:
        """The address string for this host."""
        # An empty name resolves through the IPv6 wildcard address.
        return address_str(self.name, family=self.family, fallback_host="::")

    @property
    def address(self) -> ipaddress.IPv6Address:
        """An address object built from this host's address string."""
        return ipaddress.IPv6Address(self.address_str)

    @property
    def address_str_tuple(self) -> tuple[str, int, int, int]:
        """The (address string, port, flowinfo, scope_id) tuple."""
        return (self.address_str, self.port, self.flowinfo, self.scope_id)

    def zero_port(self) -> "IPv6Host":
        """Get an equivalent host whose port is zero."""
        # Nothing to do when the port is already zero: hand back this object.
        if self.port == 0:
            return self
        return IPv6Host(self.name, 0, self.flowinfo, self.scope_id)
# Either of the concrete host types defined in this module.
IpHost = _Union[IPv4Host, IPv6Host]
# The types accepted for each positional argument of 'normalize_host'.
IpHostlike = _Union[str, int, IpHost, None]
# A host instance, or a plain tuple with the same field layout as one:
# (str, int) matches IPv4Host and (str, int, int, int) matches IPv6Host.
IpHostTuplelike = _Union[IpHost, tuple[str, int], tuple[str, int, int, int]]
[docs]
def normalize_host(
    *args: IpHostlike, default: type[IpHost] = IPv4Host
) -> IpHost:
    """
    Build a host object from loosely-typed caller arguments.

    With no arguments (or a leading None) a default-constructed 'default'
    instance is returned. An existing host instance passed first is returned
    unchanged. Otherwise the arguments are forwarded to a host constructor:
    IPv4Host for up to two arguments, IPv6Host for more.
    """

    # Nothing usable was provided: fall back to the default host type.
    if not args or args[0] is None:
        return default()

    first = args[0]
    if isinstance(first, (IPv4Host, IPv6Host)):
        return first

    # The argument count selects the address family.
    host_type = IPv4Host if len(args) <= 2 else IPv6Host
    return host_type(*args)  # type: ignore
# Names for which 'hostname' returns the local machine's fully-qualified
# domain name instead of attempting a reverse lookup.
USE_FQDN = {"::", "0.0.0.0"}
[docs]
@cache
def hostname(ip_address: str) -> str:
    """
    Look up a hostname for an IP address string.

    Names listed in USE_FQDN yield this machine's fully-qualified domain
    name. Anything else is passed to 'gethostbyaddr'; if that lookup fails
    the original string is returned. Results are cached per argument.
    """

    if ip_address in USE_FQDN:
        return _socket.getfqdn()

    try:
        return _socket.gethostbyaddr(ip_address)[0]
    except (_socket.herror, OSError):
        # Best effort only: an unresolvable name is returned unchanged.
        return ip_address
[docs]
@cache
def address_str(name: str, fallback_host: str = "localhost", **kwargs) -> str:
    """
    Resolve a name to an IP address string.

    An empty name is replaced by 'fallback_host'. Extra keyword arguments
    are forwarded to 'getaddrinfo'. Results are cached per argument set.
    """

    target = name or fallback_host
    candidates = _socket.getaddrinfo(target, None, **kwargs)

    # Use the first result; index 4 of an entry is its socket address, whose
    # first element is the address string.
    sockaddr = candidates[0][4]
    return sockaddr[0]  # type: ignore
[docs]
@cache
def hostname_port(ip_address: str, port: int) -> str:
    """
    Get a 'hostname:port' string for an address and port.

    The address is first converted with 'hostname'. Results are cached per
    (ip_address, port) pair.
    """
    host = hostname(ip_address)
    return f"{host}:{port}"
[docs]
def sockname(sock: _socket.SocketType) -> IpHost:
    """Get the socket's own address as a host object."""
    # The elements of the socket's name tuple become the host's fields.
    raw_name = sock.getsockname()
    return normalize_host(*raw_name)
[docs]
def get_free_socket(
    local: Optional[IpHost] = None,
    kind: int = _socket.SOCK_STREAM,
    reuse: bool = False,
) -> _socket.SocketType:
    """
    Create a socket bound to a local address.

    'local' is the host to bind to; None selects the default host produced
    by 'normalize_host'. 'kind' is the socket type. When 'reuse' is set,
    SO_REUSEADDR is enabled before binding.

    The caller owns the returned socket and must close it. If configuring or
    binding fails, the socket is closed here and the exception propagates.
    """

    host = normalize_host(local)
    sock = _socket.socket(host.family, kind)

    try:
        if reuse:
            sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEADDR, 1)
        sock.bind(host.address_str_tuple)
    except BaseException:
        # Don't leak the descriptor when setup fails; the original error is
        # re-raised unchanged.
        sock.close()
        raise

    return sock
[docs]
def get_free_socket_name(
    local: Optional[IpHost] = None, kind: int = _socket.SOCK_STREAM
) -> IpHost:
    """
    Create a socket to determine an arbitrary port number that's available.

    A temporary socket is bound (with address reuse enabled), its name is
    read and it is then closed. There is an inherent race condition using
    this strategy: the socket is closed before this function returns, so
    nothing reserves the reported address for the caller.
    """

    # The context manager closes the socket even if reading its name fails.
    with get_free_socket(local=local, reuse=True, kind=kind) as sock:
        return sockname(sock)
# Result type of the awaitable passed to 'try_log_connection_error'.
T = TypeVar("T")
[docs]
async def try_log_connection_error(
    future: Awaitable[T], message: str, logger: LoggerType = None
) -> Optional[T]:
    """
    Await a future, logging connection-related failures without raising.

    Returns the awaited result, or None if a ConnectionError was raised. In
    that case 'message' is passed to the logger's 'exception' method, if a
    logger was provided. Any other exception propagates to the caller.
    """

    try:
        return await future
    except ConnectionError as exc:
        if logger:
            logger.exception(message, exc_info=exc)

    return None