"""Retry wrapper for any :class:`~jeevesagent.Model`.
:class:`RetryingModel` decorates an underlying model adapter
(``OpenAIModel``, ``AnthropicModel``, ``LiteLLMModel``, custom
implementations) with the framework's retry policy. The agent loop
sees a :class:`Model` like any other; the retry mechanics are
invisible above this layer.
What it does:
* On every :meth:`complete` / :meth:`stream` call, runs the
underlying model up to :attr:`RetryPolicy.max_attempts` times.
* Catches the underlying SDK exception, runs it through
:func:`~jeevesagent.governance.classify_model_error`.
* :class:`~jeevesagent.PermanentModelError` (auth, bad request,
content filter) is re-raised immediately without backoff.
* :class:`~jeevesagent.TransientModelError` (rate limit, 5xx,
network) is retried after a backoff computed by
:func:`~jeevesagent.governance.compute_backoff`. Provider-supplied
``Retry-After`` hints set a floor on the wait.
* Unrecognised exceptions (anything :func:`classify_model_error`
returns ``None`` for) propagate unchanged — better to let an
unknown error bubble up than silently retry it.
Streaming retries are deliberately limited: once the first chunk
has been yielded to the consumer we cannot rewind, so retries only
fire while waiting for that first chunk. Errors mid-stream
propagate.
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
import anyio
from ..core.errors import (
ModelError,
PermanentModelError,
TransientModelError,
)
from ..core.types import Message, ModelChunk, ToolCall, ToolDef, Usage
from ..governance.retry import (
RetryPolicy,
classify_model_error,
compute_backoff,
)
__all__ = ["RetryingModel"]
class RetryingModel:
    """Wraps any :class:`~jeevesagent.Model` with retry semantics.

    Construction does not validate the inner model — anything that
    quacks like a Model (has ``name``, ``stream``, optional
    ``complete``) works. The wrapper keeps a stable ``name``
    matching the underlying model so telemetry and audit logs stay
    consistent.
    """

    def __init__(self, inner: Any, policy: RetryPolicy) -> None:
        self._inner = inner
        self._policy = policy
        # Mirror the inner model's name so telemetry/audit entries are
        # attributed to the real model rather than the wrapper.
        self.name: str = getattr(inner, "name", "unknown")

    @property
    def inner(self) -> Any:
        """The wrapped model. Useful for tests + introspection."""
        return self._inner

    @property
    def policy(self) -> RetryPolicy:
        """The retry policy governing this wrapper."""
        return self._policy

    async def complete(
        self,
        messages: list[Message],
        *,
        tools: list[ToolDef] | None = None,
        temperature: float = 1.0,
        max_tokens: int | None = None,
    ) -> tuple[str, list[ToolCall], Usage, str]:
        """Single-shot completion with retry on transient failures.

        Delegates to the inner model's ``complete`` and funnels any
        raised exception through :meth:`_with_retry`'s classify /
        backoff / retry loop.
        """

        async def _do_call() -> tuple[str, list[ToolCall], Usage, str]:
            result: tuple[str, list[ToolCall], Usage, str] = (
                await self._inner.complete(
                    messages,
                    tools=tools,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
            )
            return result

        result: tuple[str, list[ToolCall], Usage, str] = (
            await self._with_retry(_do_call)
        )
        return result

    async def stream(
        self,
        messages: list[Message],
        *,
        tools: list[ToolDef] | None = None,
        temperature: float = 1.0,
        max_tokens: int | None = None,
    ) -> AsyncIterator[ModelChunk]:
        """Streaming completion with retry-before-first-chunk.

        We can't roll back chunks already yielded to the consumer,
        so retry behaviour applies to errors that occur *before*
        the first :class:`ModelChunk` is produced. If the underlying
        ``stream`` raises mid-stream the error propagates unchanged.
        """
        attempt = 1
        while True:
            # Reset per attempt: only errors raised before the first
            # chunk reaches the consumer are eligible for retry.
            yielded = False
            try:
                iterator = self._inner.stream(
                    messages,
                    tools=tools,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                # First chunk — if this raises we can still retry.
                first_chunk: ModelChunk | None = None
                try:
                    first_chunk = await iterator.__anext__()
                except StopAsyncIteration:
                    return
                # Past the gate; from here on errors propagate.
                yielded = True
                yield first_chunk
                async for chunk in iterator:
                    yield chunk
                return
            except Exception as exc:  # noqa: BLE001
                if yielded:
                    # Mid-stream failure: the consumer has already
                    # seen output, so retrying would duplicate
                    # chunks. Propagate unchanged, as documented.
                    raise
                classified = classify_model_error(exc)
                if not isinstance(classified, TransientModelError):
                    if classified is not None:
                        # Permanent (or otherwise non-retryable)
                        # classification — surface it immediately.
                        raise classified from exc
                    # Unknown error — don't paper over it.
                    raise
                if attempt >= self._policy.max_attempts:
                    raise classified from exc
                delay = compute_backoff(
                    self._policy,
                    attempt,
                    retry_after=classified.retry_after,
                )
                await anyio.sleep(delay)
                attempt += 1

    async def _with_retry(self, fn: Any) -> Any:
        """Generic retry loop for non-streaming calls.

        Catches SDK exceptions, classifies, and either retries with
        backoff (transient) or raises (permanent / unknown). Mirrors
        the error handling in :meth:`stream`: only
        :class:`TransientModelError` is ever retried.
        """
        attempt = 1
        while True:
            try:
                return await fn()
            except ModelError as exc:
                # Already classified — handle by family.
                if not isinstance(exc, TransientModelError):
                    # PermanentModelError and any other ModelError
                    # subclass are not retryable.
                    raise
                if attempt >= self._policy.max_attempts:
                    raise
                delay = compute_backoff(
                    self._policy,
                    attempt,
                    retry_after=exc.retry_after,
                )
                await anyio.sleep(delay)
                attempt += 1
            except Exception as exc:  # noqa: BLE001
                classified = classify_model_error(exc)
                if classified is None:
                    # Unknown error — don't paper over it.
                    raise
                if not isinstance(classified, TransientModelError):
                    # Permanent / non-retryable classification.
                    raise classified from exc
                if attempt >= self._policy.max_attempts:
                    raise classified from exc
                delay = compute_backoff(
                    self._policy,
                    attempt,
                    retry_after=classified.retry_after,
                )
                await anyio.sleep(delay)
                attempt += 1