# Source code for jeevesagent.architecture.tree_of_thoughts

"""Tree of Thoughts: branching exploration with per-node evaluation.

Yao et al. 2023 — `Tree of Thoughts: Deliberate Problem Solving with
Large Language Models <https://arxiv.org/abs/2305.10601>`_. Useful
for combinatorial reasoning, multi-step planning, math (Game of 24),
puzzle solving — anywhere a single straight-shot ReAct trajectory
would commit too early.

Pattern (BFS beam search)
-------------------------

1. **Root** is the problem statement.
2. **For each level up to ``max_depth``:**

   a. **Expand:** for every frontier node, the proposer generates
      ``branch_factor`` candidate "thoughts" (next steps toward a
      solution).
   b. **Evaluate:** the evaluator scores each candidate 0-1 (how
      promising is this branch?).
   c. **Prune:** keep only the top ``beam_width`` scored
      candidates as the next frontier.
   d. **Early exit:** if any candidate scores ``>= solved_threshold``,
      we stop early and use that branch.

3. **Best leaf wins.** The highest-scoring leaf across the whole
   tree is the final answer (its content goes to ``session.output``).

This is the "BFS-with-beam" variant — DFS with backtracking is a
follow-up. For a structured combinatorial task, BFS-beam covers most
of what users need.

Cost
----
``branch_factor × beam_width × max_depth × 2`` model calls (one
proposer + one evaluator per candidate). With defaults
``(3, 2, 3)`` that's 36 calls. Reserve ToT for problems where the
search structure earns the cost — math/planning tasks where ReAct
visibly meanders.

Strengths
---------
* **Explicit search tree.** Every candidate, score, and decision is
  observable through ``architecture_event`` events.
* **Composable.** Wrap inside :class:`Reflexion` to learn which
  evaluation patterns predict real success.
* **Replay-correct.** Each proposer / evaluator call is a named
  ``runtime.step``, so journaled runtimes replay deterministically.

Weaknesses
----------
* **Expensive.** 30-50× a single ReAct turn for typical settings.
* **Evaluator-quality bound.** A weak evaluator picks weak branches
  and the search wastes budget on dead ends.
* **Domain-specific.** Branch-and-evaluate makes sense for
  combinatorial problems; for open-ended writing tasks, use
  Self-Refine or Actor-Critic.
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

import anyio
from pydantic import BaseModel

from ..core.ids import new_id
from ..core.types import Event, Message, Role, Usage
from .base import AgentSession, Dependencies
from .helpers import add_usage, parse_score, text_only_model_call

if TYPE_CHECKING:
    from ..agent.api import Agent


# System prompt for the proposer model. It is asked for exactly ONE
# candidate "thought" (next reasoning step), unnumbered and without a
# "Thought:" prefix, so the raw text can be stored as node content.
DEFAULT_PROPOSER_PROMPT = """\
You are exploring possible reasoning paths to solve a problem.

Given the problem and any prior steps, propose ONE next step (a
"thought") toward a solution. A thought can be a sub-step,
intermediate calculation, sub-decision, or partial answer.

Output only the thought itself — concise, one paragraph at most.
Do not number it; do not preface with "Thought:".
"""


# System prompt for the evaluator model. The reply's first line must be
# a machine-parseable "score: <0..1>" line — it is consumed by
# ``parse_score`` in the evaluation pass of ``TreeOfThoughts.run``.
DEFAULT_EVALUATOR_PROMPT = """\
You evaluate a candidate reasoning step. Given the original problem
and the proposed thought, score how promising this thought is for
arriving at the correct solution.

Output exactly one line:
score: <number between 0 and 1>

Then optionally one line of brief justification. The first line
must match the score format exactly so it can be parsed.

- 1.0 = this thought is correct and final / will obviously lead to a
  correct answer
- 0.7-0.9 = strong direction, likely correct
- 0.4-0.6 = plausible but uncertain
- 0.0-0.3 = wrong direction or contradicts the problem
"""


class ThoughtNode(BaseModel):
    """One node in the Tree-of-Thoughts search tree.

    Children are stored implicitly (each node has a ``parent_id``).
    The full tree is reconstructable from the node list ToT keeps in
    its session metadata.
    """

    # Unique node id (``new_id("thot")`` at creation time).
    id: str
    # ``None`` only for the root node, which holds the problem prompt.
    parent_id: str | None
    # The thought text (or the original prompt, for the root).
    content: str
    # Evaluator score; assigned after the eval pass (root is 1.0).
    score: float = 0.0
    # Distance from the root; root is depth 0.
    depth: int
class TreeOfThoughts:
    """Branch + evaluate + prune. BFS beam search over thoughts."""

    name = "tree-of-thoughts"

    def __init__(
        self,
        *,
        branch_factor: int = 3,
        max_depth: int = 3,
        beam_width: int = 2,
        solved_threshold: float = 1.0,
        min_score: float = 0.0,
        parallel: bool = True,
        proposer_prompt: str | None = None,
        evaluator_prompt: str | None = None,
    ) -> None:
        """Configure the search.

        Args:
            branch_factor: candidate thoughts generated per frontier
                node at each level (>= 1).
            max_depth: number of expand/evaluate/prune levels (>= 1).
            beam_width: frontier nodes kept after pruning (>= 1).
            solved_threshold: score at or above which the search stops
                early (in [0, 1]).
            min_score: floor below which candidates are dropped even
                when the beam has room (in [0, 1]; 0.0 = no floor).
            parallel: run model calls within a level concurrently.
            proposer_prompt: system prompt override for the proposer
                (falls back to ``DEFAULT_PROPOSER_PROMPT``).
            evaluator_prompt: system prompt override for the evaluator
                (falls back to ``DEFAULT_EVALUATOR_PROMPT``).

        Raises:
            ValueError: if any parameter is outside its valid range.
        """
        if branch_factor < 1:
            raise ValueError("branch_factor must be >= 1")
        if max_depth < 1:
            raise ValueError("max_depth must be >= 1")
        if beam_width < 1:
            raise ValueError("beam_width must be >= 1")
        if not 0.0 <= solved_threshold <= 1.0:
            raise ValueError(
                "solved_threshold must be in [0.0, 1.0]"
            )
        if not 0.0 <= min_score <= 1.0:
            raise ValueError("min_score must be in [0.0, 1.0]")
        self._branch_factor = branch_factor
        self._max_depth = max_depth
        self._beam_width = beam_width
        self._solved_threshold = solved_threshold
        # Floor below which a candidate is dropped REGARDLESS of beam
        # capacity. Lets bad branches die quickly instead of riding
        # along just because the beam has room. 0.0 = legacy behavior
        # (no floor).
        self._min_score = min_score
        # Run proposer + evaluator calls within a level concurrently
        # via anyio.create_task_group. Pure speedup — branch_factor *
        # beam_width independent calls are now wall-clock parallel
        # instead of sequential. Disable for deterministic test
        # ordering or when your model provider has tight rate limits.
        self._parallel = parallel
        self._proposer_prompt = (
            proposer_prompt or DEFAULT_PROPOSER_PROMPT
        )
        self._evaluator_prompt = (
            evaluator_prompt or DEFAULT_EVALUATOR_PROMPT
        )

    def declared_workers(self) -> dict[str, Agent]:
        """ToT delegates to no worker agents; always empty."""
        return {}

    async def run(
        self,
        session: AgentSession,
        deps: Dependencies,
        prompt: str,
    ) -> AsyncIterator[Event]:
        """Run BFS beam search over thoughts; yield progress events.

        The best-scoring non-root node's content is written to
        ``session.output``; the full node list and winner id are
        stashed on ``session.metadata`` for post-hoc rendering.
        """
        # Root represents the problem itself; depth 0; no model call.
        root = ThoughtNode(
            id=new_id("thot"),
            parent_id=None,
            content=prompt,
            score=1.0,  # root is "perfect" by definition
            depth=0,
        )
        all_nodes: list[ThoughtNode] = [root]
        frontier: list[ThoughtNode] = [root]
        yield Event.architecture_event(
            session.id,
            "tot.started",
            branch_factor=self._branch_factor,
            beam_width=self._beam_width,
            max_depth=self._max_depth,
        )
        for depth in range(1, self._max_depth + 1):
            # Budget gate once per level (not per model call); a block
            # interrupts the session, a warning is surfaced and we
            # continue.
            status = await deps.budget.allows_step()
            if status.blocked:
                session.interrupted = True
                session.interruption_reason = (
                    f"budget:{status.reason}"
                )
                yield Event.budget_exceeded(session.id, status)
                break
            if status.warn:
                yield Event.budget_warning(session.id, status)
            yield Event.architecture_event(
                session.id,
                "tot.level_started",
                depth=depth,
                frontier_size=len(frontier),
            )

            # === Expand: generate branch_factor candidates per node ===
            #
            # Parallel mode runs every (parent × k) proposer call
            # concurrently — independent calls, big wall-clock win
            # at moderate fan-out. Sequential mode preserves
            # deterministic ordering for tests / strict rate limits.
            candidates: list[ThoughtNode] = []
            propose_jobs = [
                (parent, k)
                for parent in frontier
                for k in range(self._branch_factor)
            ]
            # Results are index-addressed so parallel completion order
            # cannot reorder them.
            propose_results: list[tuple[str, Usage] | None] = [
                None
            ] * len(propose_jobs)

            # B023 false positive: the task_group below joins on
            # all spawned tasks before the for-loop advances, so the
            # captured ``depth`` / ``propose_results`` are stable
            # for the closure's entire lifetime.
            async def _propose_one(  # noqa: B023
                idx: int, parent: ThoughtNode, k: int
            ) -> None:
                chain = _chain_to_root(all_nodes, parent)
                msgs = _proposer_messages(
                    self._proposer_prompt, prompt, chain
                )
                # Named runtime step → journaled runtimes replay
                # deterministically (see module docstring).
                text, usage = await text_only_model_call(
                    deps,
                    f"tot_propose_d{depth}_p{parent.id}_k{k}",  # noqa: B023
                    msgs,
                )
                propose_results[idx] = (text, usage)  # noqa: B023

            if self._parallel:
                async with anyio.create_task_group() as tg:
                    for idx, (parent, k) in enumerate(propose_jobs):
                        tg.start_soon(_propose_one, idx, parent, k)
            else:
                for idx, (parent, k) in enumerate(propose_jobs):
                    await _propose_one(idx, parent, k)

            # Account usage and materialize nodes in job order.
            for (parent, _k), pr in zip(
                propose_jobs, propose_results, strict=True
            ):
                assert pr is not None
                text, usage = pr
                await deps.budget.consume(
                    tokens_in=usage.input_tokens,
                    tokens_out=usage.output_tokens,
                    cost_usd=usage.cost_usd,
                )
                session.cumulative_usage = add_usage(
                    session.cumulative_usage, usage
                )
                session.turns += 1
                candidate = ThoughtNode(
                    id=new_id("thot"),
                    parent_id=parent.id,
                    content=text.strip(),
                    depth=depth,
                )
                candidates.append(candidate)
                all_nodes.append(candidate)
                yield Event.architecture_event(
                    session.id,
                    "tot.proposed",
                    depth=depth,
                    node_id=candidate.id,
                    parent_id=parent.id,
                    content=candidate.content[:200],
                )

            # === Evaluate every candidate (parallel where possible) ===
            eval_results: list[tuple[float, Usage] | None] = [
                None
            ] * len(candidates)

            # Same B023-safe pattern as the proposer task group above.
            async def _eval_one(idx: int, cand: ThoughtNode) -> None:  # noqa: B023
                chain = _chain_to_root(all_nodes, cand)
                msgs = _evaluator_messages(
                    self._evaluator_prompt, prompt, chain, cand
                )
                text, usage = await text_only_model_call(
                    deps, f"tot_eval_d{depth}_n{cand.id}", msgs  # noqa: B023
                )
                eval_results[idx] = (parse_score(text), usage)  # noqa: B023

            if self._parallel:
                async with anyio.create_task_group() as tg:
                    for idx, cand in enumerate(candidates):
                        tg.start_soon(_eval_one, idx, cand)
            else:
                for idx, cand in enumerate(candidates):
                    await _eval_one(idx, cand)

            for cand, er in zip(
                candidates, eval_results, strict=True
            ):
                assert er is not None
                score, usage = er
                await deps.budget.consume(
                    tokens_in=usage.input_tokens,
                    tokens_out=usage.output_tokens,
                    cost_usd=usage.cost_usd,
                )
                session.cumulative_usage = add_usage(
                    session.cumulative_usage, usage
                )
                session.turns += 1
                cand.score = score
                yield Event.architecture_event(
                    session.id,
                    "tot.evaluated",
                    depth=depth,
                    node_id=cand.id,
                    score=cand.score,
                )

            # === Prune: top beam_width by score, AND drop anything
            # below the min_score floor. The floor lets a clearly-
            # losing branch die immediately even if the beam has
            # room — saves the next level's compute. ===
            candidates.sort(key=lambda n: n.score, reverse=True)
            survivors = [
                c for c in candidates if c.score >= self._min_score
            ]
            frontier = survivors[: self._beam_width]
            n_pruned_floor = len(candidates) - len(survivors)
            yield Event.architecture_event(
                session.id,
                "tot.pruned",
                depth=depth,
                kept=[n.id for n in frontier],
                kept_scores=[n.score for n in frontier],
                pruned_below_floor=n_pruned_floor,
            )

            # === Early exit if any candidate is "solved" ===
            # frontier is sorted descending, so frontier[0] is the
            # level's best score.
            if frontier and frontier[0].score >= self._solved_threshold:
                yield Event.architecture_event(
                    session.id,
                    "tot.solved",
                    depth=depth,
                    node_id=frontier[0].id,
                    score=frontier[0].score,
                )
                break
            if not frontier:
                # Beam went empty (shouldn't happen unless candidates
                # was empty too). Bail with whatever's best so far.
                yield Event.architecture_event(
                    session.id,
                    "tot.empty_beam",
                    depth=depth,
                )
                break

        # Pick the best non-root node we've seen across the whole tree.
        non_root = [n for n in all_nodes if n.parent_id is not None]
        if not non_root:
            session.output = ""
            yield Event.architecture_event(
                session.id,
                "tot.no_thoughts",
                total_nodes=len(all_nodes),
            )
            return
        best = max(non_root, key=lambda n: n.score)
        session.output = best.content
        # Stash the full tree on session.metadata so consumers can
        # render the search tree post-hoc.
        session.metadata["tot_nodes"] = [
            n.model_dump() for n in all_nodes
        ]
        session.metadata["tot_winner_id"] = best.id
        yield Event.architecture_event(
            session.id,
            "tot.completed",
            winner_id=best.id,
            winner_score=best.score,
            total_nodes=len(all_nodes),
        )
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _chain_to_root(
    all_nodes: list[ThoughtNode], leaf: ThoughtNode
) -> list[ThoughtNode]:
    """Reconstruct the chain from root to ``leaf`` by parent pointers.

    Returns root first, leaf last. If a parent id is ever missing from
    ``all_nodes``, the walk stops there and a partial chain is returned.
    """
    index = {node.id: node for node in all_nodes}
    path: list[ThoughtNode] = [leaf]
    node = leaf
    while node.parent_id is not None:
        parent = index.get(node.parent_id)
        if parent is None:
            break
        path.append(parent)
        node = parent
    path.reverse()
    return path


def _proposer_messages(
    system_prompt: str, problem: str, chain: list[ThoughtNode]
) -> list[Message]:
    """Build messages for a proposer call.

    The chain from root to current frontier node provides the "prior
    steps so far"; the proposer extends with one new thought.
    """
    # The root holds the original prompt and is sent explicitly below,
    # so only non-root nodes count as prior steps.
    steps = [node for node in chain if node.parent_id is not None]
    if steps:
        prior_text = "\n".join(
            f"Step {i + 1}: {node.content}"
            for i, node in enumerate(steps)
        )
    else:
        prior_text = "(no prior steps yet — propose the first one)"
    user_content = (
        f"Problem:\n{problem}\n\n"
        f"Prior steps:\n{prior_text}\n\n"
        f"Propose ONE next step."
    )
    return [
        Message(role=Role.SYSTEM, content=system_prompt),
        Message(role=Role.USER, content=user_content),
    ]


def _evaluator_messages(
    system_prompt: str,
    problem: str,
    chain: list[ThoughtNode],
    candidate: ThoughtNode,
) -> list[Message]:
    """Build messages for an evaluator call.

    The chain shows prior steps; the candidate is the new step being
    evaluated.
    """
    # The chain includes the candidate at its end; prior steps are the
    # non-root nodes that are not the candidate itself.
    earlier = [
        node
        for node in chain
        if node.parent_id is not None and node.id != candidate.id
    ]
    if earlier:
        prior_text = "\n".join(
            f"Step {i + 1}: {node.content}"
            for i, node in enumerate(earlier)
        )
    else:
        prior_text = "(none)"
    user_content = (
        f"Problem:\n{problem}\n\n"
        f"Prior steps:\n{prior_text}\n\n"
        f"Candidate next step:\n{candidate.content}\n\n"
        f"Score this candidate."
    )
    return [
        Message(role=Role.SYSTEM, content=system_prompt),
        Message(role=Role.USER, content=user_content),
    ]