"""
A module implementing interfaces for evaluating the underlying state of
primitives.
"""
# built-in
import asyncio
from enum import auto
from math import isclose
import operator
from typing import AsyncIterator, Callable
# internal
from runtimepy.enum.registry import RuntimeIntEnum
from runtimepy.primitives.base import Primitive, T
[docs]
class Operator(RuntimeIntEnum):
"""https://docs.python.org/3/library/operator.html."""
LESS_THAN = auto()
LESS_THAN_OR_EQUAL = auto()
EQUAL = auto()
NOT_EQUAL = auto()
GREATER_THAN_OR_EQUAL = auto()
GREATER_THAN = auto()
OPERATOR_MAP = {
Operator.LESS_THAN: operator.lt,
Operator.LESS_THAN_OR_EQUAL: operator.le,
Operator.EQUAL: operator.eq,
Operator.NOT_EQUAL: operator.ne,
Operator.GREATER_THAN_OR_EQUAL: operator.ge,
Operator.GREATER_THAN: operator.gt,
}
[docs]
class EvalResult(RuntimeIntEnum):
"""A container for all possible evaluation results."""
NOT_SET = auto()
TIMEOUT = auto()
FAIL = auto()
SUCCESS = auto()
def __bool__(self) -> bool:
"""Determine success status of this instance."""
return self is EvalResult.SUCCESS
PrimitiveEvaluator = Callable[[T, T], EvalResult]
[docs]
async def evaluate(
primitive: Primitive[T], evaluator: PrimitiveEvaluator[T], timeout: float
) -> EvalResult:
"""
Evaluate a primitive within a timeout constraint and return the result.
"""
# Short-circuit path.
if evaluator(primitive.value, primitive.value):
return EvalResult.SUCCESS
event = asyncio.Event()
result = EvalResult.TIMEOUT
def change_callback(curr: T, new: T) -> None:
"""Handle underlying value changes to the provided primitive."""
# Always update the result when called.
nonlocal result
result = evaluator(curr, new)
if result:
event.set()
# Service the callback until resolved or this evaluation times out.
with primitive.callback(change_callback):
try:
await asyncio.wait_for(event.wait(), timeout)
except asyncio.TimeoutError:
pass
return result
[docs]
async def compare_latest(
primitive: Primitive[T],
lhs: T,
timeout: float,
operation: Operator = Operator.EQUAL,
) -> EvalResult:
"""
Perform a canonical comparison of the latest value with a provided value.
"""
mapped = OPERATOR_MAP[operation]
return await evaluate(
primitive,
lambda _, new: (
EvalResult.SUCCESS if mapped(lhs, new) else EvalResult.FAIL
),
timeout,
)
[docs]
async def sample_for(
primitive: Primitive[T],
timeout: float,
count: int = -1,
current: bool = True,
) -> AsyncIterator[tuple[T, int]]:
"""
Sample a primitive until timeout or 'count' samples are emitted (if 'count'
is set).
"""
sample_queue: asyncio.Queue[tuple[T, int]] = asyncio.Queue()
keep_sampling = True
samples = 0
def poll_sample_event() -> bool:
"""Poll sample count."""
nonlocal samples
samples += 1
result = samples >= count > 0
if result:
nonlocal keep_sampling
keep_sampling = False
return result
# Publish current value.
if current:
sample_queue.put_nowait((primitive.value, primitive.last_updated_ns))
poll_sample_event()
def sample(_: T, new: T) -> None:
"""Publish to change queue."""
if keep_sampling:
sample_queue.put_nowait((new, primitive.last_updated_ns))
poll_sample_event()
# Service the callback until resolved or this evaluation times out.
with primitive.callback(sample):
try:
async with asyncio.timeout(timeout):
while keep_sampling:
yield await sample_queue.get()
except asyncio.TimeoutError:
pass
# Drain queue.
while not sample_queue.empty():
yield sample_queue.get_nowait()
[docs]
class PrimitiveIsCloseMixin(Primitive[T]):
"""Adds a wait-for-isclose method."""
[docs]
async def wait_for_isclose(
self,
value: float,
timeout: float,
rel_tol: float = 1e-09,
abs_tol: float = 0.0,
) -> EvalResult:
"""Wait for this primitive to reach a specified state."""
return await evaluate(
self,
lambda _, new: (
EvalResult.SUCCESS
if isclose(
value, self.scale(new), rel_tol=rel_tol, abs_tol=abs_tol
)
else EvalResult.FAIL
),
timeout,
)