"""
A module implementing an application information interface.
"""
# built-in
import asyncio as _asyncio
from contextlib import AsyncExitStack as _AsyncExitStack
from contextlib import contextmanager
from dataclasses import dataclass
from importlib import import_module
from logging import getLogger as _getLogger
from re import compile as _compile
from typing import Awaitable as _Awaitable
from typing import Callable as _Callable
from typing import Iterator as _Iterator
from typing import MutableMapping as _MutableMapping
from typing import Optional
from typing import TYPE_CHECKING, Any
from typing import TypeVar as _TypeVar
from typing import Union as _Union
from typing import cast
# third-party
from vcorelib.io.types import JsonObject as _JsonObject
from vcorelib.logging import LoggerMixin as _LoggerMixin
from vcorelib.logging import LoggerType as _LoggerType
from vcorelib.names import import_str_and_item
from vcorelib.namespace import Namespace as _Namespace
# internal
from runtimepy.channel.environment.sample import poll_sample_env, sample_env
from runtimepy.codec.protocol import Protocol, ProtocolFactory
from runtimepy.mapping import DEFAULT_PATTERN
from runtimepy.mixins.trig import TrigMixin
from runtimepy.net.arbiter.result import OverallResult, results
from runtimepy.net.connection import Connection as _Connection
from runtimepy.net.manager import ConnectionManager
from runtimepy.primitives import Uint32
from runtimepy.primitives.array import PrimitiveArray
from runtimepy.primitives.byte_order import DEFAULT_BYTE_ORDER, ByteOrder
from runtimepy.struct import RuntimeStructBase
from runtimepy.struct import StructMap as _StructMap
from runtimepy.task import PeriodicTask, PeriodicTaskManager
from runtimepy.tui.mixin import TuiMixin
if TYPE_CHECKING:
from runtimepy.subprocess.peer import (
RuntimepyPeer as _RuntimepyPeer, # pragma: nocover
)
# A mutable mapping of connection names to connection instances.
ConnectionMap = _MutableMapping[str, _Connection]

# Type variables for generic signatures in this module: T is bound to
# connection types, V is bound to periodic-task types and Z is unconstrained.
T = _TypeVar("T", bound=_Connection)
V = _TypeVar("V", bound=PeriodicTask)
Z = _TypeVar("Z")
class RuntimeStruct(RuntimeStructBase):
    """
    A class implementing a base runtime structure.

    Instances are attached to an application with :meth:`build`, which
    optionally registers receive/send protocols (when a "protocol_factory"
    entry is configured), runs the initialization hooks and finally selects
    the primitive array backing this struct.
    """

    # The application this struct belongs to (assigned by build()).
    app: "AppInfo"

    # The primitive array backing this struct (assigned by
    # update_byte_order()).
    array: PrimitiveArray

    # Protocol instances, only populated when a protocol factory is
    # configured (both are reset to None at the start of build()).
    recv: Optional[Protocol]
    send: Optional[Protocol]

    # Class-level default, used unless a receive protocol provides its own.
    byte_order: ByteOrder = DEFAULT_BYTE_ORDER

    # Set this for structs to automatically be polled when the application is
    # going down.
    final_poll = False

    def init_env(self) -> None:
        """Initialize this environment (synchronous hook, no-op here)."""

    async def async_init_env(self) -> None:
        """Initialize this environment (asynchronous hook, no-op here)."""

    @contextmanager
    def _final_poll(self) -> _Iterator[None]:
        """Poll when the context ends (even if it ends with an exception)."""

        try:
            yield
        finally:
            self.poll()

    async def build(self, app: "AppInfo", **kwargs) -> None:
        """
        Build a struct instance's channel environment.

        :param app: The application information instance to attach to.
        :param kwargs: Forwarded to :meth:`update_byte_order`.
        """

        self.app = app

        # Register the final poll on the application's exit stack so that it
        # runs when that stack is unwound.
        if self.final_poll:
            self.app.stack.enter_context(self._final_poll())

        byte_order = type(self).byte_order

        self.recv = None
        self.send = None

        if "protocol_factory" in self.config:
            module, name = import_str_and_item(
                cast(str, self.config["protocol_factory"])
            )
            factory: type[ProtocolFactory] = getattr(
                import_module(module), name
            )

            # NOTE(review): '__bases__' only holds direct base classes, so
            # indirect subclasses of ProtocolFactory are silently skipped
            # here - confirm that this is intended.
            if ProtocolFactory in factory.__bases__:
                self.recv = factory.singleton()

                # The receive protocol dictates the byte order.
                byte_order = self.recv.byte_order

                if self.config.get("control", True):
                    # Register separate receive and send protocols under
                    # "rx" and "tx" names.
                    with self.env.names_pushed("rx"):
                        self.env.register_protocol(self.recv, False)

                    self.send = factory.instance()
                    with self.env.names_pushed("tx"):
                        self.env.register_protocol(self.send, True)
                else:
                    self.env.register_protocol(self.recv, False)

        self.init_env()
        await self.async_init_env()
        self.update_byte_order(byte_order, **kwargs)

    def update_byte_order(self, byte_order: ByteOrder, **kwargs) -> None:
        """
        Update the over-the-wire byte order for this struct.

        :param byte_order: The byte order to use.
        :param kwargs: Forwarded to the environment's array factory when no
                       receive protocol is in use.
        :raises AssertionError: If a receive protocol is in use and
                                'byte_order' differs from that protocol's
                                byte order.
        """

        if self.recv is None:
            self.array = self.env.array(byte_order=byte_order, **kwargs).array
        else:
            # Can't change byte order at runtime in this configuration.
            assert byte_order == self.recv.byte_order

            # Store the protocol's underlying primitive array (not the
            # protocol object itself) so that 'array' always holds a
            # PrimitiveArray, consistent with the branch above.
            self.array = self.recv.array
# A type variable bound to runtime-struct types.
W = _TypeVar("W", bound=RuntimeStruct)
class TrigStruct(RuntimeStruct, TrigMixin):
    """A simple struct combining a runtime struct with the trig mixin."""

    # Iteration counter (created and registered as a channel by init_env()).
    iterations: Uint32

    def init_env(self) -> None:
        """Set up the trig mixin and register the 'iterations' channel."""

        TrigMixin.__init__(self, self.env)

        counter = Uint32()
        self.iterations = counter
        self.env.int_channel("iterations", counter, commandable=True)

    def poll(self) -> None:
        """
        Perform this struct's canonical update: dispatch trig using the
        current iteration count, then advance the count by one. Other runtime
        entities can call this method.
        """

        counter = self.iterations

        # Pylint bug?
        self.dispatch_trig(counter.value)  # pylint: disable=no-member
        counter.value += 1  # pylint: disable=no-member
class SampleStruct(TrigStruct):
    """A sample runtime structure built on top of the trig struct."""

    # For fun.
    final_poll = True

    def init_env(self) -> None:
        """Run the parent initialization, then set up the sample channels."""

        TrigStruct.init_env(self)
        sample_env(self.env)

    def poll(self) -> None:
        """
        Perform the parent's canonical update, then poll the sample
        environment. Other runtime entities can call this method.
        """

        TrigStruct.poll(self)
        poll_sample_env(self.env)
@dataclass
class AppInfo(_LoggerMixin):
    """Runtime references that are handed to network applications."""

    # Logger available for applications to use.
    logger: _LoggerType

    # Exit stack that applications can use to guarantee resource cleanup.
    stack: _AsyncExitStack

    # Connections by name, plus the connection manager.
    connections: ConnectionMap
    conn_manager: ConnectionManager

    # Namespace of connection names.
    names: _Namespace

    # Event that, once set, tells the application to begin shutting down.
    stop: _asyncio.Event

    # Configuration data (may come from a configuration file).
    config: _JsonObject

    ports: dict[str, int]

    tui: TuiMixin

    tasks: dict[str, PeriodicTask]
    task_manager: PeriodicTaskManager[Any]

    # Tracks application state.
    results: OverallResult

    # Structs by name.
    structs: _StructMap

    # Peers by name.
    peers: dict[str, "_RuntimepyPeer"]

    def with_new_logger(self, name: str) -> "AppInfo":
        """Create a copy of this instance that only differs by its logger."""

        return AppInfo(
            logger=_getLogger(name),
            stack=self.stack,
            connections=self.connections,
            conn_manager=self.conn_manager,
            names=self.names,
            stop=self.stop,
            config=self.config,
            ports=self.ports,
            tui=self.tui,
            tasks=self.tasks,
            task_manager=self.task_manager,
            results=self.results,
            structs=self.structs,
            peers=self.peers,
        )

    def search(
        self,
        *names: str,
        pattern: str = DEFAULT_PATTERN,
        kind: type[T] = _Connection,  # type: ignore
    ) -> _Iterator[T]:
        """
        Iterate over connections whose names match the search arguments and
        that are instances of 'kind'. When the default pattern is used,
        connections of that kind known to the connection manager are yielded
        afterwards (skipping any that were already yielded).
        """

        yielded: set[T] = set()

        for conn_name in self.names.search(*names, pattern=pattern):
            candidate = self.connections[conn_name]
            if not isinstance(candidate, kind):
                continue
            yield candidate
            yielded.add(candidate)

        # The connection manager (for server connections) is only consulted
        # when the default pattern is used.
        if pattern != DEFAULT_PATTERN:
            return

        for candidate in self.conn_manager.by_type(kind):
            if candidate in yielded:
                continue
            yield candidate
            yielded.add(candidate)

    def search_tasks(self, kind: type[V], pattern: str = ".*") -> _Iterator[V]:
        """Iterate over tasks of a given type whose names match a pattern."""

        matcher = _compile(pattern)
        for task_name, candidate in self.tasks.items():
            if matcher.search(task_name) is None:
                continue
            if isinstance(candidate, kind):
                yield candidate

    def search_structs(
        self, kind: type[W], pattern: str = ".*"
    ) -> _Iterator[W]:
        """Iterate over structs of a given type whose names match a pattern."""

        matcher = _compile(pattern)
        for struct_name, candidate in self.structs.items():
            if matcher.search(struct_name) is None:
                continue
            if isinstance(candidate, kind):
                yield candidate

    def single(
        self,
        *names: str,
        pattern: str = ".*",
        kind: type[T] = _Connection,  # type: ignore
    ) -> T:
        """
        Find exactly one connection via search(); it is an assertion failure
        if the search does not produce exactly one result.
        """

        # NOTE(review): search() only consults the connection manager when
        # the pattern equals DEFAULT_PATTERN, so whether this method's default
        # reaches it depends on that constant's value - confirm.
        matches = list(self.search(*names, pattern=pattern, kind=kind))
        assert len(matches) == 1, matches
        return matches[0]

    async def all_finalized(self) -> None:
        """Wait until every connection and every task has been finalized."""

        # Connections first, then tasks.
        entities = [*self.connections.values(), *self.tasks.values()]
        await _asyncio.gather(
            *(entity.env.wait_finalized() for entity in entities)
        )

    def exceptions(self) -> _Iterator[Exception]:
        """Iterate over every exception recorded in the results."""

        for stage in self.results:
            yield from (
                item.exception for item in stage if item.exception is not None
            )

    @property
    def raised_exception(self) -> bool:
        """Whether or not any exception was recorded in the results."""

        recorded = list(self.exceptions())
        return len(recorded) > 0

    def result(self, logger: _LoggerType = None) -> bool:
        """
        Compute the overall boolean result for the application, logging an
        error first if a logger was provided and an exception was recorded.
        """

        uncaught = logger is not None and self.raised_exception
        if uncaught:
            logger.error("An exception was not caught at runtime!")

        return results(self.results, logger=logger)

    def original_config(self) -> dict[str, Any]:
        """
        Build a dictionary that is closer to the original configuration data
        than the .config attribute: non-"root" keys are nested under
        "config", and the contents of "root" (other than its own "config"
        key) are placed at the top level.
        """

        nested = {
            key: val for key, val in self.config.items() if key != "root"
        }

        rebuilt: dict[str, Any] = {"config": nested}
        rebuilt.update(
            (key, val)
            for key, val in self.config.get("root", {}).items()  # type: ignore
            if key != "config"
        )
        return rebuilt

    def config_param(self, key: str, default: Z, strict: bool = False) -> Z:
        """
        Look up a configuration parameter, falling back to 'default' if it's
        absent. In strict mode it is an assertion failure if the key is
        missing.
        """

        # Creates the "config" entry under "root" if it doesn't exist yet.
        root = self.config["root"]
        params: dict[str, Z] = root.setdefault("config", {})  # type: ignore

        if strict:
            assert key in params, (key, params)

        return params.get(key, default)
# A network application: a callable that accepts an AppInfo instance and
# returns an awaitable integer (presumably a status code - confirm).
NetworkApplication = _Callable[[AppInfo], _Awaitable[int]]

# Either a single network application or a list of them.
NetworkApplicationlike = _Union[NetworkApplication, list[NetworkApplication]]

# A list of lists of network applications (presumably one inner list per
# application stage - confirm).
ArbiterApps = list[list[NetworkApplication]]