"""
A module implementing a WebSocket connection interface.
"""
from __future__ import annotations
# built-in
import asyncio as _asyncio
from contextlib import AsyncExitStack as _AsyncExitStack
from contextlib import asynccontextmanager as _asynccontextmanager
from contextlib import suppress as _suppress
from logging import getLogger as _getLogger
from typing import Any as _Any
from typing import AsyncIterator as _AsyncIterator
from typing import Awaitable as _Awaitable
from typing import Callable as _Callable
from typing import Optional as _Optional
from typing import TypeVar as _TypeVar
from typing import Union as _Union
# third-party
from vcorelib.asyncio import log_exceptions as _log_exceptions
from vcorelib.io import BinaryMessage
import websockets
from websockets.asyncio.client import ClientConnection as _ClientConnection
from websockets.asyncio.server import Server as _Server
from websockets.asyncio.server import ServerConnection as _ServerConnection
from websockets.asyncio.server import serve as _serve
from websockets.exceptions import ConnectionClosed as _ConnectionClosed
# internal
from runtimepy.net import normalize_host as _normalize_host
from runtimepy.net import sockname as _sockname
from runtimepy.net.connection import Connection
from runtimepy.net.connection import EchoConnection as _EchoConnection
from runtimepy.net.connection import NullConnection as _NullConnection
from runtimepy.net.manager import ConnectionManager as _ConnectionManager
from runtimepy.net.ssl import handle_possible_ssl
# Logger for module-level (server/application) messages.
LOG = _getLogger(__name__)

# Instance type for 'WebsocketConnection' and its subclasses.
T = _TypeVar("T", bound="WebsocketConnection")

# Result type of an awaitable wrapped by the connection-closed handler.
V = _TypeVar("V")

# Optional server-side initializer: given a freshly created connection, it
# asynchronously reports whether that connection should be serviced.
ConnectionInit = _Callable[[T], _Awaitable[bool]]
async def websocket_connect(uri: str, **kwargs) -> _ClientConnection:
    """
    Open a client websocket connection to 'uri'.

    When the caller does not pass 'use_ssl', it is inferred from the URI:
    a 'wss' prefix enables it. All keyword arguments are then run through
    'handle_possible_ssl' before being given to 'websockets.connect'.
    """

    if "use_ssl" not in kwargs:
        kwargs["use_ssl"] = uri.startswith("wss")

    connect = getattr(websockets, "connect")
    options = handle_possible_ssl(**kwargs)
    return await connect(uri, **options)  # type: ignore
class WebsocketConnection(Connection):
    """A simple websocket connection interface."""

    def __init__(
        self,
        protocol: _Union[_ClientConnection, _ServerConnection],
        **kwargs,
    ) -> None:
        """
        Initialize this connection.

        The websocket is stored first; the logger name is built from its
        local and remote addresses. Remaining keyword arguments are
        forwarded to the base 'Connection' initializer.
        """

        self.protocol = protocol
        super().__init__(
            _getLogger(
                f"W {_normalize_host(*self.protocol.local_address)} -> "
                f"{_normalize_host(*self.protocol.remote_address)}"
            ),
            **kwargs,
        )

        # Connection-instantiation arguments (used when restarting).
        self._uri: str = ""
        self._conn_kwargs: dict[str, _Any] = {}

    async def _handle_connection_closed(
        self, task: _Awaitable[V]
    ) -> _Optional[V]:
        """
        Await 'task', disabling this connection if the websocket closed.

        Returns the awaited result, or None when 'ConnectionClosed' was
        raised.
        """

        result = None
        try:
            result = await task
        except _ConnectionClosed:
            self.disable("connection closed")
        return result

    async def _await_message(self) -> _Optional[_Union[BinaryMessage, str]]:
        """
        Await the next message. Return None on error or failure.

        A closed websocket disables this connection; a received message
        has its length added to the receive metrics.
        """

        data = await self._handle_connection_closed(self.protocol.recv())
        if data is not None:
            self.metrics.rx.increment(len(data))
        return data

    async def _send_websocket_message(
        self, data: _Union[BinaryMessage, str]
    ) -> None:
        """
        Send one message, counting it in the transmit metrics only if the
        send completed. A closed websocket disables this connection.
        """

        try:
            await self.protocol.send(data)
        except _ConnectionClosed:
            self.disable("connection closed")
        else:
            self.metrics.tx.increment(len(data))

    async def _send_text_message(self, data: str) -> None:
        """Send a text message."""
        await self._send_websocket_message(data)

    # NOTE(review): "binay" spelling kept as-is; presumably this overrides
    # a base-class hook of the same name — TODO confirm before renaming.
    async def _send_binay_message(self, data: BinaryMessage) -> None:
        """Send a binary message."""
        await self._send_websocket_message(data)

    async def close(self) -> None:
        """Close this connection (closes the underlying websocket)."""
        await self.protocol.close()

    @classmethod
    async def create_connection(
        cls: type[T], uri: str, markdown: str = None, **kwargs
    ) -> T:
        """
        Connect a client to an endpoint.

        'markdown' is forwarded to the initializer. The remaining keyword
        arguments go to 'websocket_connect' and are retained, along with
        'uri', so that 'restart' can reconnect.
        """

        inst = cls(await websocket_connect(uri, **kwargs), markdown=markdown)

        # Stored for connection restart capability.
        inst._uri = uri
        inst._conn_kwargs = {**kwargs}

        return inst

    async def restart(self) -> bool:
        """
        Reset necessary underlying state for this connection to 'process'
        again.

        Only connections that stored a URI (via 'create_connection' or
        'client') can restart: a new websocket is opened to that URI with
        the stored keyword arguments. Returns True on success; timeouts and
        OS-level errors while reconnecting yield False.
        """

        result = False

        if self._uri:
            with _suppress(TimeoutError, OSError):
                self.protocol = await websocket_connect(
                    self._uri, **self._conn_kwargs
                )
                result = True

        return result

    @classmethod
    @_asynccontextmanager
    async def client(
        cls: type[T], uri: str, markdown: str = None, **kwargs
    ) -> _AsyncIterator[T]:
        """
        A wrapper for connecting a client.

        Yields a connected instance. The websocket's own async context is
        exited when this context exits. 'use_ssl' defaults to whether 'uri'
        starts with 'wss'.
        """

        kwargs.setdefault("use_ssl", uri.startswith("wss"))

        async with getattr(websockets, "connect")(
            uri, **handle_possible_ssl(**kwargs)
        ) as protocol:
            inst = cls(protocol, markdown=markdown)

            # Stored for connection restart capability.
            inst._uri = uri
            inst._conn_kwargs = {**kwargs}

            yield inst

    @classmethod
    def server_handler(
        cls: type[T],
        init: ConnectionInit[T] = None,
        stop_sig: _asyncio.Event = None,
        manager: _ConnectionManager = None,
    ) -> _Callable[[_ServerConnection], _Awaitable[None]]:
        """
        A wrapper for passing in a websocket handler and initializing a
        connection.

        The returned handler wraps each websocket in an instance of this
        class and disables auto-restart for it. It then runs 'init', if
        given; a falsy result means the handler returns without servicing
        the connection. Otherwise, the connection is either handed to
        'manager', with the handler waiting until the connection is
        disabled or 'stop_sig' is set, or processed in place when there is
        no manager.
        """

        async def _handler(protocol: _ServerConnection) -> None:
            """A handler that runs the callers initialization function."""

            conn = cls(protocol)

            # Turn off connection restarts for server-side connections.
            conn.env["auto_restart"] = False

            if init is None or await init(conn):
                if manager is not None:
                    # Allow the connection manager to process this connection.
                    await manager.queue.put(conn)

                    # Wait for either the connection to be disabled, or for
                    # the stop signal to be set.
                    tasks = [_asyncio.create_task(conn.disabled_event.wait())]
                    if stop_sig is not None:
                        tasks.append(_asyncio.create_task(stop_sig.wait()))

                    try:
                        await _asyncio.wait(
                            tasks, return_when=_asyncio.FIRST_COMPLETED
                        )
                    finally:
                        # Cancel whichever waiters did not complete; doing
                        # this in 'finally' also covers this handler being
                        # cancelled while waiting ('asyncio.wait' does not
                        # cancel its tasks).
                        pending = [x for x in tasks if not x.done()]
                        for task in pending:
                            task.cancel()
                        for task in pending:
                            with _suppress(_asyncio.CancelledError):
                                await task
                        _log_exceptions(pending, logger=conn.logger)

                # If there's no connection manager, just process the
                # connection here.
                else:
                    await conn.process(stop_sig=stop_sig)

        return _handler

    @classmethod
    @_asynccontextmanager
    async def create_pair(
        cls: type[T], serve_kwargs: dict[str, _Any] = None
    ) -> _AsyncIterator[tuple[T, T]]:
        """
        Obtain a connected pair of WebsocketConnection objects.

        A temporary server is started on an ephemeral port. A client
        connects to it over 'localhost', and the context yields
        '(server_side, client_side)'.
        """

        server_conn: _Optional[T] = None
        server_ready = _asyncio.Event()

        async def server_init(protocol: _ServerConnection) -> None:
            """
            Create the server side of the pair and keep its handler alive.

            Per the websockets documentation, a server connection is closed
            as soon as its handler returns. This therefore waits until the
            websocket closes, e.g. when the client side exits.
            """

            nonlocal server_conn
            assert server_conn is None
            try:
                server_conn = cls(protocol)
            finally:
                # Always signal, so a failed construction surfaces as the
                # assertion below instead of a hang.
                server_ready.set()

            await protocol.wait_closed()

        async with _AsyncExitStack() as stack:
            if serve_kwargs is None:
                serve_kwargs = {}
            serve_kwargs = handle_possible_ssl(client=False, **serve_kwargs)
            is_ssl = "ssl" in serve_kwargs

            # Start a server.
            server = await stack.enter_async_context(
                _serve(server_init, host="0.0.0.0", port=0, **serve_kwargs)
            )
            host = list(server.sockets)[0].getsockname()

            # Connect a client (exited first, which closes the pair).
            client_conn = await stack.enter_async_context(
                cls.client(f"ws{'s' if is_ssl else ''}://localhost:{host[1]}")
            )

            # The client handshake can finish before the server handler
            # runs; wait for the server side to exist before yielding.
            await server_ready.wait()
            assert server_conn is not None

            yield server_conn, client_conn

    @classmethod
    @_asynccontextmanager
    async def serve(
        cls: type[T],
        init: ConnectionInit[T] = None,
        stop_sig: _asyncio.Event = None,
        manager: _ConnectionManager = None,
        **kwargs,
    ) -> _AsyncIterator[_Server]:
        """
        Serve a WebSocket server.

        Connections are handled by 'server_handler(init, stop_sig,
        manager)'. The remaining keyword arguments, after SSL handling, go
        to the websockets 'serve' function. Each listening socket is
        logged, and the server is yielded.
        """

        kwargs = handle_possible_ssl(client=False, **kwargs)
        is_ssl = "ssl" in kwargs

        async with _serve(
            cls.server_handler(init=init, stop_sig=stop_sig, manager=manager),
            **kwargs,
        ) as server:
            for socket in server.sockets:
                LOG.info(
                    "Started WebSocket server listening on 'ws%s://%s'.",
                    "s" if is_ssl else "",
                    _sockname(socket),
                )
            yield server

    @classmethod
    async def app(
        cls: type[T],
        stop_sig: _asyncio.Event,
        init: ConnectionInit[T] = None,
        manager: _ConnectionManager = None,
        serving_callback: _Callable[[_Server], None] = None,
        **kwargs,
    ) -> None:
        """
        Run a WebSocket-server application.

        While the server from 'serve' is running, this awaits
        'manager.manage(stop_sig)'. A fresh 'ConnectionManager' is created
        when 'manager' is None. 'serving_callback', if given, receives the
        server before management starts.
        """

        if manager is None:
            manager = _ConnectionManager()

        async with cls.serve(
            init=init, stop_sig=stop_sig, manager=manager, **kwargs
        ) as server:
            if serving_callback is not None:
                serving_callback(server)

            LOG.info("WebSocket Application starting.")
            await manager.manage(stop_sig)
            LOG.info("WebSocket Application stopped.")

        LOG.info("WebSocket Server closed.")
class EchoWebsocketConnection(WebsocketConnection, _EchoConnection):
    """
    Websocket connection combined with the 'EchoConnection' behavior.
    """
class NullWebsocketConnection(WebsocketConnection, _NullConnection):
    """
    Websocket connection combined with the 'NullConnection' behavior.
    """