"""ReAct: the canonical observe-think-act loop.
Each turn:
1. Check budget; emit warning / exceeded.
2. Call the model with current messages + available tools.
3. Stream tokens, accumulate text + tool calls + usage.
4. If no tool calls, the model is done; break.
5. Otherwise, dispatch all tool calls in parallel through hooks
→ permissions → tool host. Append results to messages.
6. Loop.
This is the v0.1.x default behaviour, lifted verbatim out of
``Agent._loop`` and behind the :class:`Architecture` protocol.
Behaviour is identical; the refactor only changes the shape.
"""
from __future__ import annotations
import contextlib
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any
import anyio
from ..core.types import (
Event,
Message,
PermissionDecision,
Role,
ToolCall,
ToolResult,
Usage,
)
from ..security.audit import AuditLog
from .base import AgentSession, Dependencies
from .helpers import add_usage
# Module-level singleton no-op async context manager. ``contextlib.nullcontext``
# implements both the sync and async protocols (since Python 3.10), so we can
# reuse a single instance everywhere a hot path wants to *maybe* enter a
# telemetry span via ``async with (NULL_CTX if fast else tel.trace(...)):``.
# Sharing one instance is safe: its enter/exit are stateless no-ops, so many
# tasks may enter it concurrently without interference.
_NULL_CTX: contextlib.AbstractAsyncContextManager[None] = contextlib.nullcontext()
if TYPE_CHECKING:
from ..agent.api import Agent
class ReAct:
    """Observe-think-act in a tight loop.

    The default architecture for every :class:`Agent`. Other
    architectures wrap or replace this strategy; see ``Subagent.md``.

    ``max_turns`` overrides ``Dependencies.max_turns`` for this
    architecture only — useful when wrapping ReAct inside another
    architecture that sets its own per-leaf cap (Reflexion,
    Plan-and-Execute, etc.). ``None`` means "use whatever the Agent
    was configured with".
    """

    name = "react"

    def __init__(self, *, max_turns: int | None = None) -> None:
        # ``None`` defers to ``Dependencies.max_turns`` at run() time.
        self._max_turns_override = max_turns

    def declared_workers(self) -> dict[str, Agent]:
        # ReAct spawns no named sub-agents of its own.
        return {}

    async def run(
        self,
        session: AgentSession,
        deps: Dependencies,
        prompt: str,
    ) -> AsyncIterator[Event]:
        """Drive the loop for one prompt, yielding :class:`Event`s.

        Mutates ``session`` in place (messages, turn count, cumulative
        usage, output, interruption flags) and returns when the model
        stops emitting tool calls, the turn cap is hit, or the budget
        blocks a step.
        """
        # 1. Seed context (system prompt + memory recall + user prompt).
        session.messages.extend(
            await _build_seed_messages(deps, session.instructions, prompt)
        )
        max_turns = (
            self._max_turns_override
            if self._max_turns_override is not None
            else deps.max_turns
        )
        # 2. The ReAct loop.
        while True:
            if session.turns >= max_turns:
                session.interrupted = True
                session.interruption_reason = "max_turns_exceeded"
                break
            # Budget gate. ``NoBudget.allows_step()`` returns OK every
            # time — when ``fast_budget`` is True we skip the call
            # entirely.
            if not deps.fast_budget:
                status = await deps.budget.allows_step()
                if status.blocked:
                    session.interrupted = True
                    session.interruption_reason = f"budget:{status.reason}"
                    if not deps.fast_telemetry:
                        await deps.telemetry.emit_metric(
                            "jeeves.budget.exceeded",
                            1,
                            session_id=session.id,
                            reason=status.reason,
                        )
                    yield Event.budget_exceeded(session.id, status)
                    break
                if status.warn:
                    yield Event.budget_warning(session.id, status)
            session.turns += 1
            # One span per turn; the shared no-op manager avoids any
            # telemetry overhead on the fast path.
            turn_trace: contextlib.AbstractAsyncContextManager[Any] = (
                _NULL_CTX
                if deps.fast_telemetry
                else deps.telemetry.trace(
                    "jeeves.turn",
                    turn=session.turns,
                    session_id=session.id,
                )
            )
            async with turn_trace:
                # 2a. Model call.
                #
                # Non-streaming hot path: when nobody's reading from
                # ``agent.stream()``, prefer the model adapter's
                # single-shot ``complete()`` method. Skips per-token
                # async-generator yields + per-chunk Event
                # constructions and uses a non-streaming HTTP call
                # on the wire (no SSE overhead). About 100-200 ms
                # per turn faster on token-heavy responses.
                #
                # Streaming path: yield each ModelChunk as a
                # ``model_chunk`` Event so a ``stream()`` consumer
                # sees tokens as they arrive.
                tool_defs = await deps.tools.list_tools()
                tool_calls: list[ToolCall] = []
                usage = Usage()
                if not deps.streaming and hasattr(deps.model, "complete"):
                    model_trace: contextlib.AbstractAsyncContextManager[Any] = (
                        _NULL_CTX
                        if deps.fast_telemetry
                        else deps.telemetry.trace(
                            "jeeves.model.complete",
                            model=deps.model.name,
                            turn=session.turns,
                            session_id=session.id,
                            tool_count=len(tool_defs),
                        )
                    )
                    async with model_trace:
                        if deps.fast_runtime:
                            # Inline path: skip ``runtime.step`` wrapping
                            # (no idempotency_key derivation, no journal
                            # write — InProcRuntime's step is a literal
                            # ``await fn(*args)``).
                            text, tool_calls, usage, _finish_reason = (
                                await deps.model.complete(
                                    session.messages,
                                    tools=tool_defs or None,
                                )
                            )
                        else:
                            text, tool_calls, usage, _finish_reason = (
                                await deps.runtime.step(  # type: ignore[func-returns-value]
                                    f"model_call_{session.turns}",
                                    deps.model.complete,
                                    session.messages,
                                    tools=tool_defs or None,
                                )
                            )
                else:
                    text_parts: list[str] = []
                    stream_trace: contextlib.AbstractAsyncContextManager[Any] = (
                        _NULL_CTX
                        if deps.fast_telemetry
                        else deps.telemetry.trace(
                            "jeeves.model.stream",
                            model=deps.model.name,
                            turn=session.turns,
                            session_id=session.id,
                            tool_count=len(tool_defs),
                        )
                    )
                    async with stream_trace:
                        if deps.fast_runtime:
                            chunks = deps.model.stream(
                                session.messages,
                                tools=tool_defs or None,
                            )
                        else:
                            chunks = deps.runtime.stream_step(
                                f"model_call_{session.turns}",
                                deps.model.stream,
                                session.messages,
                                tools=tool_defs or None,
                            )
                        # Accumulate text / tool calls / usage while
                        # forwarding every chunk to the consumer.
                        async for chunk in chunks:
                            yield Event.model_chunk(session.id, chunk)
                            if chunk.kind == "text" and chunk.text is not None:
                                text_parts.append(chunk.text)
                            elif (
                                chunk.kind == "tool_call"
                                and chunk.tool_call is not None
                            ):
                                tool_calls.append(chunk.tool_call)
                            elif chunk.kind == "finish" and chunk.usage is not None:
                                usage = chunk.usage
                    text = "".join(text_parts)
                # 2b. Update budget + telemetry + cumulative usage.
                if not deps.fast_budget:
                    await deps.budget.consume(
                        tokens_in=usage.input_tokens,
                        tokens_out=usage.output_tokens,
                        cost_usd=usage.cost_usd,
                    )
                if not deps.fast_telemetry:
                    await deps.telemetry.emit_metric(
                        "jeeves.tokens.input",
                        usage.input_tokens,
                        session_id=session.id,
                        model=deps.model.name,
                    )
                    await deps.telemetry.emit_metric(
                        "jeeves.tokens.output",
                        usage.output_tokens,
                        session_id=session.id,
                        model=deps.model.name,
                    )
                    if usage.cost_usd:
                        await deps.telemetry.emit_metric(
                            "jeeves.cost.usd",
                            usage.cost_usd,
                            session_id=session.id,
                            model=deps.model.name,
                        )
                session.cumulative_usage = add_usage(
                    session.cumulative_usage, usage
                )
                session.output = text
                session.messages.append(
                    Message(
                        role=Role.ASSISTANT,
                        content=text,
                        tool_calls=tuple(tool_calls),
                    )
                )
                # 2c. No tool calls = model is done.
                if not tool_calls:
                    break
                # 2d. Dispatch tool calls in parallel.
                #
                # Two paths picked from ``deps.streaming``:
                #
                # - **Buffered (default for ``agent.run``)**: events
                #   for each tool are appended to a per-call list
                #   while the task runs. After all tasks complete
                #   we yield the lists in call order. One task
                #   group, no memory channel, no per-call clones —
                #   ~25-35% faster end-to-end on tool-heavy turns
                #   in the JeevesAgent vs LangChain bench.
                #
                # - **Streaming (``agent.stream``)**: events flow
                #   through a memory-object channel as tasks emit
                #   them. Slower but preserves arrival-order
                #   semantics so a consumer that breaks out of the
                #   stream cancels long-running tools promptly.
                results: list[ToolResult | None] = [None] * len(tool_calls)
                if deps.streaming:
                    async for ev in _dispatch_streaming(
                        deps, session, tool_calls, results
                    ):
                        yield ev
                else:
                    events_per_call: list[list[Event]] = [
                        [] for _ in tool_calls
                    ]
                    async with anyio.create_task_group() as tg:
                        for i, call in enumerate(tool_calls):
                            tg.start_soon(
                                _run_one_tool,
                                deps,
                                session,
                                call,
                                i,
                                results,
                                events_per_call[i],
                            )
                    # Task group has closed: every worker is done, so
                    # yielding buffered events here is deterministic.
                    for event_list in events_per_call:
                        for ev in event_list:
                            yield ev
                # 2e. Append tool results to messages in the order calls
                # were emitted (preserves model's expected ordering).
                for r, c in zip(results, tool_calls, strict=True):
                    final = (
                        r
                        if r is not None
                        else ToolResult.error_(c.id, "no_result")
                    )
                    session.messages.append(
                        Message(
                            role=Role.TOOL,
                            content=_format_tool_message(final),
                            tool_call_id=final.call_id,
                        )
                    )
# ---------------------------------------------------------------------------
# Helpers (module-level so multiple architectures can reuse)
# ---------------------------------------------------------------------------
async def _build_seed_messages(
    deps: Dependencies, instructions: str, prompt: str
) -> list[Message]:
    """Assemble the initial message list for a run.

    Layout: system instructions, working-memory blocks, cross-session
    recall (facts + episodes from OTHER conversations of the same
    user), this session's rehydrated chat history, and finally the
    current user prompt.

    Two memory layers feed the model:

    * **Cross-session recall** (``recall_facts`` + episodic
      ``recall``) — surfaced as a single SYSTEM block, partitioned by
      ``deps.context.user_id``.
    * **Within-session continuity** (``session_messages``) — prior
      user/assistant turns of THIS conversation replayed as real
      :class:`Message` history, keyed on ``deps.context.session_id``;
      reusing the same id continues the conversation.

    Episodes belonging to this very session are dropped from the
    recall block — the rehydrated history already carries them.
    Older memory backends missing newer methods or kwargs degrade
    silently to empty results.
    """
    uid = deps.context.user_id
    sid = deps.context.session_id

    seed: list[Message] = [Message(role=Role.SYSTEM, content=instructions)]

    # Working-memory blocks ride along as a second SYSTEM message.
    working = await deps.memory.working()
    if working:
        seed.append(
            Message(
                role=Role.SYSTEM,
                content="\n\n".join(b.format() for b in working),
            )
        )

    try:
        facts: list[Any] = await deps.memory.recall_facts(
            prompt, limit=5, user_id=uid
        )
    except (AttributeError, TypeError):
        # 0.1.x backends without recall_facts, or older signatures
        # missing ``user_id``: silent fallback.
        facts = []

    try:
        episodes = await deps.memory.recall(
            prompt, kind="episodic", limit=3, user_id=uid
        )
    except TypeError:
        # Backends that haven't picked up the user_id kwarg yet.
        episodes = await deps.memory.recall(prompt, kind="episodic", limit=3)

    # Cross-session recall only — this conversation's episodes get
    # rehydrated as real chat turns below, so drop them here.
    if sid is not None:
        episodes = [e for e in episodes if e.session_id != sid]

    sections: list[str] = []
    if facts:
        sections.append(
            "Known facts:\n" + "\n".join(f"- {f.format()}" for f in facts)
        )
    if episodes:
        sections.append(
            "Relevant past episodes:\n"
            + "\n".join(f"- {e.format()}" for e in episodes)
        )
    if sections:
        seed.append(Message(role=Role.SYSTEM, content="\n\n".join(sections)))

    # Within-session continuity: replay this conversation's prior
    # turns so the model sees the chat thread itself, not just a
    # recall summary. Backends without persisted logs yield [].
    if sid is not None:
        try:
            history = await deps.memory.session_messages(
                sid, user_id=uid, limit=20
            )
        except (AttributeError, TypeError):
            history = []
        seed.extend(history)

    seed.append(Message(role=Role.USER, content=prompt))
    return seed
async def _dispatch_streaming(
    deps: Dependencies,
    session: AgentSession,
    tool_calls: list[ToolCall],
    results: list[ToolResult | None],
) -> AsyncIterator[Event]:
    """Streaming path for tool dispatch — events flow as they happen.

    Each worker emits its tool_call event BEFORE running the tool
    (via the channel) and its tool_result event AFTER, so a stream
    consumer can break the iterator on a tool_call event and the
    surrounding task group cancels the in-flight worker promptly.

    Results are written into ``results`` by slot so the caller can
    still append tool messages in original call order afterwards.
    """
    from anyio.streams.memory import MemoryObjectSendStream

    # Each worker emits at most 2 events (tool_call + tool_result), so a
    # 4x buffer guarantees no worker ever blocks on send.
    send, receive = anyio.create_memory_object_stream[Event](
        max_buffer_size=len(tool_calls) * 4
    )

    async def _stream_one(
        slot: int,
        call: ToolCall,
        sender: MemoryObjectSendStream[Event],
    ) -> None:
        # Worker owns its clone of the send stream; the ``async with``
        # closes it on exit (success or cancellation) so the receive
        # side can finish once every worker is done.
        async with sender:
            await sender.send(Event.tool_call(session.id, call))
            if not deps.fast_audit:
                await _audit(
                    deps.audit_log,
                    session.id,
                    "model",
                    "tool_call",
                    {
                        "tool": call.tool,
                        "call_id": call.id,
                        "args": dict(call.args),
                        "destructive": call.destructive,
                        "turn": session.turns,
                    },
                )
            result = await _run_single_tool(
                deps, call, turn=session.turns, slot=slot
            )
            results[slot] = result
            if not deps.fast_audit:
                await _audit(
                    deps.audit_log,
                    session.id,
                    "system",
                    "tool_result",
                    {
                        "tool": call.tool,
                        "call_id": result.call_id,
                        "ok": result.ok,
                        "denied": result.denied,
                        "error": result.error,
                        "reason": result.reason,
                        "turn": session.turns,
                    },
                )
            await sender.send(Event.tool_result(session.id, result))

    async with anyio.create_task_group() as tg:
        # Close the original ``send`` once all clones are handed out;
        # the channel then ends when the last worker clone closes,
        # which terminates the ``async for`` below.
        async with send:
            for i, call in enumerate(tool_calls):
                tg.start_soon(_stream_one, i, call, send.clone())
        async with receive:
            async for ev in receive:
                yield ev
async def _run_one_tool(
    deps: Dependencies,
    session: AgentSession,
    call: ToolCall,
    slot: int,
    results: list[ToolResult | None],
    event_buffer: list[Event],
) -> None:
    """Buffered per-call worker for the non-streaming dispatch path.

    Records a tool_call event, runs the tool through hooks +
    permissions + runtime, stores the outcome at ``results[slot]``,
    then records a tool_result event. Events go into ``event_buffer``
    (not a channel) so the caller can replay them in deterministic
    call order once all parallel workers have finished.
    """
    event_buffer.append(Event.tool_call(session.id, call))

    if not deps.fast_audit:
        request_payload = {
            "tool": call.tool,
            "call_id": call.id,
            "args": dict(call.args),
            "destructive": call.destructive,
            "turn": session.turns,
        }
        await _audit(
            deps.audit_log, session.id, "model", "tool_call", request_payload
        )

    outcome = await _run_single_tool(deps, call, turn=session.turns, slot=slot)
    results[slot] = outcome

    if not deps.fast_audit:
        outcome_payload = {
            "tool": call.tool,
            "call_id": outcome.call_id,
            "ok": outcome.ok,
            "denied": outcome.denied,
            "error": outcome.error,
            "reason": outcome.reason,
            "turn": session.turns,
        }
        await _audit(
            deps.audit_log, session.id, "system", "tool_result", outcome_payload
        )

    event_buffer.append(Event.tool_result(session.id, outcome))
async def _run_single_tool(
    deps: Dependencies, call: ToolCall, *, turn: int, slot: int
) -> ToolResult:
    """Hooks → permission → journaled tool host call → post-hook.

    Layers (hooks / permissions / runtime / telemetry) are skipped
    when their corresponding ``fast_*`` flag on ``deps`` is set —
    e.g. ``AllowAll`` permissions short-circuit the ``check`` call,
    an empty ``HookRegistry`` short-circuits ``pre_tool`` /
    ``post_tool`` dispatch, and ``InProcRuntime`` lets us inline
    the tool host call (skipping idempotency-key derivation).

    Never raises: tool exceptions are folded into an error
    :class:`ToolResult`.
    """
    t0 = anyio.current_time()
    span: contextlib.AbstractAsyncContextManager[Any] = (
        _NULL_CTX
        if deps.fast_telemetry
        else deps.telemetry.trace(
            "jeeves.tool",
            tool=call.tool,
            call_id=call.id,
            turn=turn,
        )
    )
    async with span:
        # Gate 1: pre-tool hook (skipped for empty registries).
        pre: PermissionDecision = (
            PermissionDecision.allow_()
            if deps.fast_hooks
            else await deps.hooks.pre_tool(call)
        )
        if pre.deny:
            outcome = ToolResult.denied_(
                call.id, pre.reason or "denied by hook"
            )
        else:
            # Gate 2: permission policy (AllowAll short-circuits the
            # dataclass round-trip entirely).
            policy: PermissionDecision = (
                PermissionDecision.allow_()
                if deps.fast_permissions
                else await deps.permissions.check(call, context={})
            )
            if policy.deny:
                outcome = ToolResult.denied_(
                    call.id, policy.reason or "denied by policy"
                )
            elif policy.ask and not pre.allow:
                # Policy wants interactive approval and no hook
                # pre-approved the call.
                outcome = ToolResult.denied_(
                    call.id,
                    policy.reason or "approval required; no approver",
                )
            else:
                # Gate 3: the actual tool host call, inlined or
                # journaled through the runtime.
                try:
                    if deps.fast_runtime:
                        outcome = await deps.tools.call(
                            call.tool,
                            call.args,
                            call_id=call.id,
                        )
                    else:
                        outcome = await deps.runtime.step(
                            f"tool_call_{turn}_{slot}",
                            deps.tools.call,
                            call.tool,
                            call.args,
                            call_id=call.id,
                            idempotency_key=call.idempotency_key(),
                        )
                except Exception as exc:  # noqa: BLE001
                    outcome = ToolResult.error_(call.id, str(exc))
        # Post-hook and duration metric run for every path, denied
        # ones included, inside the same span.
        if not deps.fast_hooks:
            await deps.hooks.post_tool(call, outcome)
        if not deps.fast_telemetry:
            await deps.telemetry.emit_metric(
                "jeeves.tool.duration_ms",
                (anyio.current_time() - t0) * 1000,
                tool=call.tool,
                ok=outcome.ok,
                denied=outcome.denied,
            )
    return outcome
async def _audit(
audit_log: AuditLog | None,
session_id: str,
actor: str,
action: str,
payload: dict[str, Any],
) -> None:
if audit_log is None:
return
await audit_log.append(
session_id=session_id,
actor=actor,
action=action,
payload=payload,
)
def _format_tool_message(result: ToolResult) -> str:
if result.ok:
return str(result.output)
if result.denied:
return f"DENIED: {result.reason or 'no reason given'}"
return f"ERROR: {result.error or 'unknown'}"
__all__ = ["ReAct"]