Source code for runtimepy.net.manager
"""
A module implementing a connection manager.
"""
from __future__ import annotations
# built-in
import asyncio as _asyncio
from contextlib import suppress as _suppress
from typing import Iterator as _Iterator
from typing import Optional as _Optional
from typing import TypeVar as _TypeVar
# third-party
from vcorelib.asyncio import log_exceptions as _log_exceptions
from vcorelib.logging import LoggerMixin
from vcorelib.math import metrics_time_ns as _metrics_time_ns
# internal
from runtimepy.net.connection import Connection as _Connection
# Generic connection type: bounding it to Connection lets 'by_type' return
# an iterator typed as the specific class that was requested.
T = _TypeVar("T", bound=_Connection)
[docs]
class ConnectionManager(LoggerMixin):
    """A class for managing connection processing at runtime."""

    def __init__(self) -> None:
        """Initialize this connection manager."""

        super().__init__()

        # Connections placed on this queue are picked up by 'manage', which
        # starts a processing task for each one.
        self.queue: _asyncio.Queue[_Connection] = _asyncio.Queue()

        # True while a 'manage' call owns the processing loop.
        self._running = False

        # Connections currently tracked by the processing loop.
        self._conns: list[_Connection] = []

    @property
    def num_connections(self) -> int:
        """Return the number of managed connections."""
        return len(self._conns)

    @property
    def connection_tasks(self) -> _Iterator[_asyncio.Task[None]]:
        """Iterate over the 'tasks' of every managed connection."""
        for conn in self._conns:
            yield from conn.tasks

    def by_type(self, kind: type[T]) -> _Iterator[T]:
        """
        Iterate over connections of a specific type.

        Yields every managed connection that is an instance of 'kind'
        (subclasses included, per 'isinstance').
        """
        for conn in self._conns:
            if isinstance(conn, kind):
                yield conn

    def reset_metrics(self) -> None:
        """Reset the metrics of every managed connection."""
        for conn in self._conns:
            conn.metrics.reset()

    def poll_metrics(self, time_ns: _Optional[int] = None) -> None:
        """
        Poll the metrics of every managed connection.

        When 'time_ns' is not provided, a single timestamp is sampled up
        front so that every connection is polled with the same value.
        """
        if time_ns is None:
            time_ns = _metrics_time_ns()

        for conn in self._conns:
            conn.metrics.poll(time_ns=time_ns)

    async def manage(self, stop_sig: _asyncio.Event) -> None:
        """
        Handle incoming connections until the stop signal is set.

        Connections are taken from 'self.queue' and a processing task is
        started for each one. If a call is already managing connections,
        additional calls only wait for 'stop_sig' and then return.

        The running flag and this method's internal helper tasks are always
        released on exit, including when the call is cancelled or an
        exception propagates, so a later call can manage connections again.
        """

        # Allow this method to be reentrant.
        if self._running:
            await stop_sig.wait()
            return

        self._running = True

        stop_sig_task = _asyncio.create_task(stop_sig.wait())
        new_conn_task: _Optional[_asyncio.Task[_Connection]] = None
        tasks: list[_asyncio.Task[None]] = []
        self._conns = []

        try:
            while not stop_sig.is_set():
                # Create a new-connection handler.
                if new_conn_task is None:
                    # Wait for a connection to be established.
                    new_conn_task = _asyncio.create_task(self.queue.get())

                # Wait for any task to complete.
                await _asyncio.wait(
                    [stop_sig_task, new_conn_task] + tasks,  # type: ignore
                    return_when=_asyncio.FIRST_COMPLETED,
                )

                # Filter completed tasks out of the working set.
                next_tasks = _log_exceptions(tasks)

                # Filter out disabled connections.
                enabled = []
                for conn in self._conns:
                    if not conn.disabled:
                        enabled.append(conn)

                    # Check if this connection should be restarted.
                    elif conn.auto_restart:
                        next_tasks.append(
                            _asyncio.create_task(
                                conn.process(stop_sig=stop_sig)
                            )
                        )
                        enabled.append(conn)

                self._conns = enabled

                # If a new connection was made, register a task for
                # processing it.
                if new_conn_task.done():
                    new_conn = new_conn_task.result()
                    self._conns.append(new_conn)
                    next_tasks.append(
                        _asyncio.create_task(
                            new_conn.process(stop_sig=stop_sig)
                        )
                    )
                    # A fresh queue-get task is created on the next pass.
                    new_conn_task = None

                tasks = next_tasks

            with self.log_time("Shutting down", reminder=True):
                # Stop waiting on the queue before the grace period so a
                # connection queued during shutdown is not consumed and
                # then dropped.
                if new_conn_task is not None:
                    new_conn_task.cancel()

                # Allow existing tasks to clean up.
                if tasks:
                    with _suppress(_asyncio.TimeoutError):
                        await _asyncio.wait_for(_asyncio.gather(*tasks), 1.0)
        finally:
            # Runs on every exit path (normal, cancelled or failed). Without
            # this, an abnormal exit would leave helper tasks pending and
            # the running flag set, turning every later call into a no-op
            # wait. Cancelling an already-finished task has no effect.
            if new_conn_task is not None:
                new_conn_task.cancel()
            stop_sig_task.cancel()
            self._running = False