Source code for runtimepy.net.mtu

"""
A module implementing utilities for calculating maximum transmission-unit
sizes.
"""

# built-in
from contextlib import suppress
from enum import IntEnum
from functools import cache
import logging
import socket
from typing import Callable

# internal
from runtimepy.net.util import (
    IpHost,
    IpHostlike,
    get_free_socket,
    normalize_host,
)

LOG = logging.getLogger(__name__)


class SocketConstants(IntEnum):
    """
    Raw ``IPPROTO_IP``-level option numbers needed to perform mtu discovery.

    NOTE(review): these appear to be the Linux values (see ip(7)); other
    platforms may use different numbers or not support them at all, which is
    why callers wrap their use in ``suppress(OSError)``.
    """

    # Option used to read back the current path MTU from a socket.
    IP_MTU = 14

    # Option controlling path-MTU discovery / don't-fragment behaviour.
    IP_MTU_DISCOVER = 10

    # Value for IP_MTU_DISCOVER that always sets don't-fragment.
    IP_PMTUDISC_DO = 2
# Nominal Ethernet MTU, used as the final fallback when discovery fails.
ETHERNET_MTU = 1500

# Bytes reserved for the IP header; 60 is presumably chosen as the largest
# possible IPv4 header (base header plus maximum options).
IP_HEADER_SIZE = 60

# IP header overhead plus the fixed 8-byte UDP header.
UDP_HEADER_SIZE = 8 + IP_HEADER_SIZE

# Default probe size: the largest UDP payload expected to fit in one
# Ethernet frame after worst-case header overhead.
UDP_DEFAULT_MTU = ETHERNET_MTU - UDP_HEADER_SIZE
def socket_discover_mtu(
    sock: socket.SocketType,
    probe_size: int,
    fallback: int,
    probe_create: Callable[[int], bytes] = bytes,
) -> int:
    """
    Send a large frame and indicate that we want to perform mtu discovery,
    and not fragment any frames.

    :param sock: socket to probe through; ``sock.send`` is used, so it must
        already have a default destination (e.g. be connected).
    :param probe_size: size in bytes of the probe payload.
    :param fallback: value returned when the socket cannot report an MTU
        (the query raises ``OSError``) or reports zero.
    :param probe_create: factory building a payload of the requested size.
    :returns: the MTU reported by the socket, or ``fallback``.
    """

    orig_val = None

    # see ip(7), force the don't-fragment flag and perform mtu discovery
    # such that the socket object can be queried for actual mtu upon error

    # Suppress platform incompatibility errors.
    with suppress(OSError):
        orig_val = sock.getsockopt(
            socket.IPPROTO_IP, SocketConstants.IP_MTU_DISCOVER
        )
    with suppress(OSError):
        sock.setsockopt(
            socket.IPPROTO_IP,
            SocketConstants.IP_MTU_DISCOVER,
            SocketConstants.IP_PMTUDISC_DO,
        )

    try:
        count = sock.send(probe_create(probe_size))
        LOG.info("mtu probe successfully sent %d bytes", count)
    except OSError as exc:
        # A send failure (e.g. the probe exceeding the path MTU) is logged,
        # not raised; the socket can still be queried for the MTU below.
        LOG.exception(
            "Error sending %d-byte MTU probe payload:",
            probe_size,
            exc_info=exc,
        )
    finally:
        # Restore the original value even if probe creation or sending
        # raised something other than OSError, so the caller's socket is
        # never left with discovery forced on.
        if orig_val is not None:
            with suppress(OSError):
                sock.setsockopt(
                    socket.IPPROTO_IP,
                    SocketConstants.IP_MTU_DISCOVER,
                    orig_val,
                )

    result = 0

    # Suppress platform incompatibility errors.
    with suppress(OSError):
        result = sock.getsockopt(socket.IPPROTO_IP, SocketConstants.IP_MTU)

    return result if result else fallback
@cache
def host_discover_mtu(
    local: IpHost,
    destination: IpHost,
    probe_size: int,
    fallback: int,
    kind: int = socket.SOCK_DGRAM,
    probe_create: Callable[[int], bytes] = bytes,
) -> int:
    """
    Perform MTU discovery given a local and remote host plus probe size.

    A socket is obtained from ``get_free_socket`` for ``local`` (with its
    port zeroed via ``zero_port()``), connected to ``destination`` and
    probed with ``socket_discover_mtu``. The socket is closed on every path.

    Results are memoized by ``functools.cache`` per unique argument tuple,
    so repeated calls with identical arguments do not probe again.

    :returns: the discovered MTU, or ``fallback`` if none could be read.
    """

    sock = get_free_socket(local=local.zero_port(), kind=kind)
    try:
        sock.connect(destination.address_str_tuple)
        return socket_discover_mtu(
            sock, probe_size, fallback, probe_create=probe_create
        )
    finally:
        # Close unconditionally so a failed connect() or probe does not
        # leak the socket.
        sock.close()
def discover_mtu(
    *destination: IpHostlike,
    local: IpHost = None,
    probe_size: int = UDP_DEFAULT_MTU,
    fallback: int = ETHERNET_MTU,
    kind: int = socket.SOCK_DGRAM,
    probe_create: Callable[[int], bytes] = bytes,
) -> int:
    """
    Find the maximum transmission unit for an IPv4 payload sent to a host.

    ``destination`` is normalized with ``normalize_host``. ``local`` is
    normalized as well, with the destination's type passed as ``default``.
    Probing (and result caching) is delegated to ``host_discover_mtu``.

    :returns: the discovered MTU, or ``fallback`` if none could be read.
    """

    remote = normalize_host(*destination)
    source = normalize_host(local, default=type(remote))

    mtu = host_discover_mtu(
        source,
        remote,
        probe_size,
        fallback,
        kind=kind,
        probe_create=probe_create,
    )

    LOG.info(
        "Discovered MTU to (%s -> %s) is %d (probe size: %d).",
        source,
        remote,
        mtu,
        probe_size,
    )
    return mtu