Metadata-Version: 2.4
Name: treequest
Version: 0.3.2
Summary: A Tree Search Library with Flexible API for LLM Inference-Time Scaling
Author-email: Kou Misaki <kou.misaki@sakana.ai>, Takuya Akiba <takiba@sakana.ai>
License-File: LICENSE
Requires-Python: >=3.11
Requires-Dist: joblib>=1.5.2
Requires-Dist: pyyaml>=6.0.0
Requires-Dist: scipy>=1.0.0
Requires-Dist: ulid-py>=1.1.0
Provides-Extra: abmcts-m
Requires-Dist: jax>=0.5.2; extra == 'abmcts-m'
Requires-Dist: numpy>=1.26.4; extra == 'abmcts-m'
Requires-Dist: numpyro>=0.17.0; extra == 'abmcts-m'
Requires-Dist: packaging>=23; extra == 'abmcts-m'
Requires-Dist: pandas>=2.2.3; extra == 'abmcts-m'
Requires-Dist: pymc>=5.21.0; extra == 'abmcts-m'
Requires-Dist: xarray>=2025.1.2; extra == 'abmcts-m'
Provides-Extra: all
Requires-Dist: graphviz>=0.20.3; extra == 'all'
Requires-Dist: jax>=0.5.2; extra == 'all'
Requires-Dist: jinja2>=3.1.0; extra == 'all'
Requires-Dist: numpy>=1.26.4; extra == 'all'
Requires-Dist: numpyro>=0.17.0; extra == 'all'
Requires-Dist: packaging>=23; extra == 'all'
Requires-Dist: pandas>=2.2.3; extra == 'all'
Requires-Dist: pillow>=10.0.0; extra == 'all'
Requires-Dist: pymc>=5.21.0; extra == 'all'
Requires-Dist: xarray>=2025.1.2; extra == 'all'
Provides-Extra: vis
Requires-Dist: graphviz>=0.20.3; extra == 'vis'
Requires-Dist: jinja2>=3.1.0; extra == 'vis'
Requires-Dist: pillow>=10.0.0; extra == 'vis'
Description-Content-Type: text/markdown

# TreeQuest
[![Python](https://img.shields.io/badge/python-3.11-blue)](https://www.python.org)
[![GitHub license](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://www.apache.org/licenses/LICENSE-2.0)
[![Checks status](https://github.com/SakanaAI/treequest/actions/workflows/checks.yml/badge.svg?branch=main)](https://github.com/SakanaAI/treequest/actions)
[![Tests status](https://github.com/SakanaAI/treequest/actions/workflows/tests.yml/badge.svg?branch=main)](https://github.com/SakanaAI/treequest/actions)

[![arXiv](https://img.shields.io/badge/arXiv-2503.04412-B31B1B?logo=data:image/svg+xml;base64,PHN2ZyBpZD0ibG9nb21hcmsiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgdmlld0JveD0iMCAwIDE3LjczMiAyNC4yNjkiPjxnIGlkPSJ0aW55Ij48cGF0aCBkPSJNNTczLjU0OSwyODAuOTE2bDIuMjY2LDIuNzM4LDYuNjc0LTcuODRjLjM1My0uNDcuNTItLjcxNy4zNTMtMS4xMTdhMS4yMTgsMS4yMTgsMCwwLDAtMS4wNjEtLjc0OGgwYS45NTMuOTUzLDAsMCwwLS43MTIuMjYyWiIgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoLTU2Ni45ODQgLTI3MS41NDgpIiBmaWxsPSIjYmRiOWI0Ii8+PHBhdGggZD0iTTU3OS41MjUsMjgyLjIyNWwtMTAuNjA2LTEwLjE3NGExLjQxMywxLjQxMywwLDAsMC0uODM0LS41LDEuMDksMS4wOSwwLDAsMC0xLjAyNy42NmMtLjE2Ny40LS4wNDcuNjgxLjMxOSwxLjIwNmw4LjQ0LDEwLjI0MmgwbC02LjI4Miw3LjcxNmExLjMzNiwxLjMzNiwwLDAsMC0uMzIzLDEuMywxLjExNCwxLjExNCwwLDAsMCwxLjA0LjY5QS45OTIuOTkyLDAsMCwwLDU3MSwyOTNsOC41MTktNy45MkExLjkyNCwxLjkyNCwwLDAsMCw1NzkuNTI1LDI4Mi4yMjVaIiB0cmFuc2Zvcm09InRyYW5zbGF0ZSgtNTY2Ljk4NCAtMjcxLjU0OCkiIGZpbGw9IiNiMzFiMWIiLz48cGF0aCBkPSJNNTg0LjMyLDI5My45MTJsLTguNTI1LTEwLjI3NSwwLDBMNTczLjUzLDI4MC45bC0xLjM4OSwxLjI1NGEyLjA2MywyLjA2MywwLDAsMCwwLDIuOTY1bDEwLjgxMiwxMC40MTlhLjkyNS45MjUsMCwwLDAsLjc0Mi4yODIsMS4wMzksMS4wMzksMCwwLDAsLjk1My0uNjY3QTEuMjYxLDEuMjYxLDAsMCwwLDU4NC4zMiwyOTMuOTEyWiIgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoLTU2Ni45ODQgLTI3MS41NDgpIiBmaWxsPSIjYmRiOWI0Ii8+PC9nPjwvc3ZnPg==)](https://arxiv.org/abs/2503.04412)
[![Blog](https://img.shields.io/badge/Blog-Sakana%20AI-blue?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAFXElEQVRoge2XW4yV1RXHf+s751v7zDlnwAnMKNgRCOUmBi/BYipNG2MTrKaRanyxqdUH+sCTb5poYrz0og8afWjqQ9OkCbZAkza0pU0sA0RroiZ4GQtmZphvf+coUGC4zIU5l+9bfZghJQPMnDnDmCY9v8dvr732+q9v77XXhhYtWrRo0aJFi/9fZD6dG2SPQqFYIJckuBpkAyMwsERIMgHV0QyV3DCjy6EiYLNd45oLMMgcVW7OBtwRpKxLYSWwFOhEKABqkIoxanAW4YSkRAR8jtCbZuhdPsqJRsVcMwG9oAuU+0zYjrEO6ADyF9cQOGcwBIwahCIsFKPTIDPpoopwHuOMwadBwJ7ucXYIVOdNgIHEcF3i2JwxngXuNKgJnBLwQE8iHKTCR8u5PKvHoTCaZ10m4S4x7gXWAx0CCw2yCE8uq/DavAj4EMJOxxYznhDYAgQI7xrsDVL2S43ebrjQqD8DKeVZYnW+CfwUWCWw9aYqf2w2xqtSdqyOHLu8ctorqVfMK/88WuB6g2AuvmPHSq986ZULR6B9JvuGFzMIjha4PgrZnsB+MR5GqItwEBgGlgVVbhRI5yKAiW20BOHA2gm/09KQgCPQXsrxo0ydXSK8gZEX2IHwqDh+gPAJsEgCNtvcz9UDAAJ7GrHPzmRQCvmGwcuWcgfQjvAZ8FRbhQOdMMw4lBx/SuFuMW49Bm3AWDPB94ET2GowXDMOzlnAgGNVauwEugXKwCtJhVdWwPildkmGt6UOqbByrEiRkeYEOGWrGYsRevIhx6cvoA0IcMpQWmW3pWRS+O37NT5+BJKpdjbGYZSaGF1tCa6Z4I9AuwnbMOoYB24Y4Uwj86YV8LVhTpfg2QRsatYvZQWMR8KQQKFmM2/LK9Hm+A7GbcDJxNgrUG9k3oyHuBsuXBq8QWYQciVoM9DnJn00uuCVGICFGD8EOgz+saLGB43ObShb/UW6tMbtkrImEm4MhAUpZD2MP26c/nFATMoCgy/CYPZCQuV7BvchHEvg+dk0ddOWvMEcyzPGdjMeZKK3aQd0ipkxcePmBYZSeHVkIa/dcpKRRgLod3w9hHcwFoux7aYav240+GkFlJVH6sIbYnQBI8ApjH6EQwhlJtrfvMEaM7YILJsUIwhejKdzVfZ0cXUhcZ6lVmc3sCmAPxSrbOuAs7MRcNUtlAj3yMSB/D3wZwt4d9k4fupNaxBEyg5gqQh7MRzGvQZvjik7jwW8sGSc6LIEwaKkzs+AjcCRNMsvOqqzC35avoTFJceqPlgwnZ1XbvZK5B3H4xzfKsMiH/KYdxzzShIp7w0qa6eIznlll1cqXjlTDrlrrjd4UxhkY+WZSKnGjr8bhJPfpZTj25FyeLLR+0uUZ4lBUHas9o59Xkm9o1TOcc9XHvhFSsqGSPncK6MDOTZPHR8I2eiVQ7FSjx2/LCmPxsqHk3/msHc80NNgJbzmlKDNOw56JY2VF6/WRpdCNsWO2CuVSDnnFYsdPQOO1Tv/+xprmmbUS3+RzqTKrwLjbhP2acjrUr38cJ8osrha4XabKL0qwlkx3hyr8HwjrfK84JX1XtkdKdVI+bjUxqapB/DfUBzM8Vjs2D/54DkXKW95x3d7L79HvhoMst7xk0iJYqUeKR994VgzNfg4ZKMP2Rcp571i3vGv2PH9marZvDOorPVKX6Scj5S3+ot0XRwzCKIcK3yOn3tl1CtJrMSDynOliffBvNFw7e0DFyoPGVha4K8rz3Du4ljkuD8wXjLYAAyZ8Lsg5TfdNQ7JFdrv/zki5QWv1L3jb3HInfOd9WtOD2TLBW7to7nHTIsWLVq0aNEs/wHPLhDZrKccngAAAABJRU5ErkJggg==)](https://sakana.ai/ab-mcts/)

![AB-MCTS](https://raw.githubusercontent.com/SakanaAI/treequest/main/images/figure1.png)

A flexible answer tree search library featuring **AB-MCTS**, useful for (but not limited to) LLM inference-time scaling.

## Quick Start
```python
import random
from pathlib import Path

import treequest as tq

# Each node is associated with a user-definable `state`.
State = str

# 1. Define a function to be used for node generation.
def generate(parent_state: State | None) -> tuple[State, float]:
    """Generates new states and scores based on the parent state."""
    if parent_state is None: # None represents the expansion from root.
        new_state = "Initial state"
    else:
        new_state = f"State after {parent_state}"

    score = random.random() # A score for the new state; It should be normalized to the [0, 1] range.
    return new_state, score

# 2. Instantiate the algorithm and a search tree object.
algo = tq.ABMCTSA()
search_tree = algo.init_tree()

# 3. Run the search with a generation budget (10 in this case).
for _ in range(10):
    search_tree = algo.step(search_tree, {'Action A': generate})

# 4. Extract the best score and state.
best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best state: {best_state}, Score: {best_node_score}")

# 5. Visualize the search tree.
output_file_basename = Path("ab_mcts_a_search_tree")
tq.render(search_tree, output_file_basename, format="html")  # Generates `ab_mcts_a_search_tree.html`
```

Alternatively, you can use an ask–tell interface with batched AB-MCTS sampling steps:
```python
import random
import treequest as tq

State = str

def generate(parent_state: State | None) -> tuple[State, float]:
    ...

generate_fns = {"Action A": generate}
actions = list(generate_fns.keys())

# We use batch_size=5 here
batch_size = 5

# It runs AB-MCTS sampling step with 5 processes in parallel
algo = tq.ABMCTSM(max_process_workers=batch_size)
search_tree = algo.init_tree()

total_budget = 50
num_steps = total_budget // batch_size
for _ in range(num_steps):
    # ask_batch returns a list of `Trial` object, which has action, parent_state and trial_id attrs
    search_tree, trials = algo.ask_batch(search_tree, batch_size, actions)

    for trial in trials:
        result = generate_fns[trial.action](trial.parent_state)
        # Call tell method with trial_id to update search_tree
        search_tree = algo.tell(search_tree, trial.trial_id, result)

best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
```

In particular for AB-MCTS-M, each `step` call can be slow. If you encounter slow execution, prefer `ask_batch` over `step`.
Please note that using a large `batch_size` can skew the search-tree shape (i.e., the tree may become too wide), so it is best to avoid overly large `batch_size`, see [PROFILING.md](./docs/PROFILING.md) for example trees.
We recommend `batch_size<=5` as a starting point.

## Features
- Easy-to-use API with customizable node generation and node scoring logic.
- **AB-MCTS-A** and **AB-MCTS-M**, as well as **Multi-LLM AB-MCTS** support (See [our paper](https://arxiv.org/abs/2503.04412) for algorithm details).
- Checkpointing and resuming searches.

## Installation
First, install [`uv`](https://github.com/astral-sh/uv?tab=readme-ov-file#installation). Then you can install TreeQuest with the following command:
```bash
uv add "treequest[all]"
```

Alternatively, you can use pip to install TreeQuest:
```bash
pip install "treequest[all]"
```

There are optional dependencies for ABMCTS-M and visualization features. You can install them with:
```bash
# with all dependencies
uv add "treequest[all]"

# without dependencies
uv add treequest

# with individual dependencies
uv add "treequest[abmcts-m]"
uv add "treequest[vis]"
```

## Usage
### Using an LLM as a Node Generator
You can use any object as a node state. You only need to define a generating function that returns a `(state, score)` tuple and takes the parent state as an argument:
```python
import dataclasses

import treequest as tq

@dataclasses.dataclass
class State:
    llm_answer: str
    score: float

def generate(parent_state: State | None) -> tuple[State, float]:
    """Generate a new node by calling an LLM."""
    if parent_state is None:
        state = initial_generation()
    else:
        state = refine_answer(parent_state.llm_answer, parent_state.score)

    return state, state.score

def initial_generation() -> State:
    """
    Call LLM API to generate an initial answer.
    """
    ...

def refine_answer(llm_answer: str, score: float) -> State:
    """
    Call LLM API to refine an answer.
    """
    ...


algo = tq.ABMCTSM()
search_tree = algo.init_tree()
for i in range(20):
    search_tree = algo.step(search_tree, {'Action Label': generate})
    # Logging best node during the search.
    if (i + 1) % 5 == 0:
        best_interim_state, _ = tq.top_k(search_tree, algo, k=1)[0]
        print(f"Iteration {i+1}: Best state so far = {best_interim_state}")

best_state, _ = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best Answer: {best_state.llm_answer}, Best Score: {best_state.score}")
```

### Using Multiple LLMs (and Beyond)

TreeQuest supports multiple action types. For example, you can provide multiple generation functions backed by different LLMs to represent different action types:

```python
from functools import partial

import treequest as tq

def generate(llm_name: str, parent_state=None):
    """
    Call LLM API using litellm, vllm, etc., to generate a new node
    """
    ...
    return new_state, new_score

llm_names = ["o4-mini", "gemini-2.5-pro"]
# Create dict of different actions backed by different LLMs.
generate_fns = {llm_name: partial(generate, llm_name=llm_name) for llm_name in llm_names}

algo = tq.StandardMCTS()
search_tree = algo.init_tree()
for _ in range(20):
    search_tree = algo.step(search_tree, generate_fns)
```
The variation is not limited to LLM types; you can use different prompts, actions, scoring logic, etc. in `generate_fns`.

### Batch Semantics and Concurrency
- Algorithms are stateless objects; the evolving tree/search state is returned from `init_tree`, `step`, `ask`, and `tell`.
- `ask_batch(state, batch_size, actions)` returns exactly `batch_size` Trial objects to expand next.
  - Non-queue algorithms (e.g., `ABMCTSM`, `ABMCTSA`, `MultiArmedBanditUCB`) return exactly `batch_size` Trials.
  - Queue-based algorithms (e.g., `StandardMCTS`, `BestFirstSearchAlgo`, `TreeOfThoughtsBFS`) precompute a set of parent/action pairs and duplicate them if needed to fill `batch_size`.
- `tell(state, trial_id, (new_state, score))` reflects the result for the corresponding Trial.
  - Order-independent: you can call `tell` in any order; reflection is tied to `trial_id`.
  - Idempotent: calling `tell` twice on the same `trial_id` does not add extra nodes.
  - For queue-based algorithms, over-told Trials beyond possible number of childs from a parent node (e.g., `(# actions)*samples_per_action` for StandardMCTS) become `INVALID` and are not reflected.
- Scores are expected to be normalized to the `[0, 1]` range.
  - `ABMCTSM` exposes an internal `_advanced_config` for advanced knobs (e.g., bypassing the runtime check or tuning Bayesian priors). In most cases, the defaults are sufficient; only change it if you understand the impact.

## Algorithms

### ABMCTS-A: ABMCTS with Node Aggregation

ABMCTS-A uses node aggregation for adaptive branching:

```python
import treequest as tq

# Instantiate the ABMCTS-A algorithm.
ab_mcts_a = tq.ABMCTSA()

search_tree = ab_mcts_a.init_tree()
for _ in range(50):
    search_tree = ab_mcts_a.step(search_tree, generate_fns)
```

### ABMCTS-M: ABMCTS with Mixed Models

ABMCTS-M leverages PyMC's mixed modeling capabilities:

```python
import treequest as tq

# Instantiate the ABMCTS-M algorithm.
ab_mcts_m = tq.ABMCTSM()

search_tree = ab_mcts_m.init_tree()
for _ in range(30):
    search_tree = ab_mcts_m.step(search_tree, generate_fns)
```

**NOTE**: To run AB-MCTS-M, you need to install extra dependencies with the `treequest[abmcts-m]` or `treequest[all]` option.

## Visualization
TreeQuest provides visualization utilities to render the search tree. You can visualize the search tree using the `tq.render` function as shown in the Quick Start section.
You need to install optional dependencies for visualization either by `uv add treequest[vis]` or `uv add treequest[all]`.

```python
import html
import base64
from io import BytesIO
from pathlib import Path

import treequest as tq
from PIL import Image

class State:
    text: str
    image: Image.Image

def state_formatter_html(state: State) -> str:
    """Formats the state for HTML visualization."""
    buffer = BytesIO()
    state.image.save(buffer, format="WEBP")
    img_str = base64.b64encode(buffer.getvalue()).decode()
    # HTML representation is allowed (ensure safety when using untrusted content)
    # In case state.text cannot be trusted, we recommend e.g., html.escape(state.text) to avoid html injection.
    return f'<p>{state.text}</p><br/><img src="data:image/webp;base64,{img_str}" width=100% />'

generate_fns = {
    "action_1": ...  # Define your node generation functions (actions) here
}

algo = tq.ABMCTSA()
search_tree = algo.init_tree()
output_dir = Path("search_tree"); output_dir.mkdir(exist_ok=True)
for i in range(10):
    search_tree = algo.step(search_tree, generate_fns)
    # Rendering "latest.html" for checking search progress on the fly
    if i % 2 == 0 or i == 9:
        tq.render(
            search_tree,
            output_basename=output_dir / "latest",
            format="html",
            state_formatter=state_formatter_html, # Use html custom formetter
        )

tq.render(search_tree, output_basename=output_dir / "final", format="pdf")
tq.render(search_tree, output_basename=output_dir / "final", format="html", state_formatter=state_formatter_html)
```

> **IMPORTANT:** When using HTML format, ensure that the HTML file is securely handled, especially if the state formatter includes raw HTML content. Avoid opening untrusted HTML files in your browser. For example, XSS (cross site scripting) attacks can occur if the state includes malicious HTML/JavaScript code.

## Requirements

- Python 3.11+

## Contributing

Contributions are welcome! Please see [CONTRIBUTING.md](./CONTRIBUTING.md) for development tips.

## Citation
```bibtex
@inproceedings{inoue2025wider,
  title={Wider or Deeper?  Scaling {LLM} Inference-Time Compute with Adaptive Branching Tree Search},
  author={Yuichi Inoue and Kou Misaki and Yuki Imajuku and So Kuroki and Taishi Nakamura and Takuya Akiba},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025},
  url={https://openreview.net/forum?id=jAsr5GHt3P}
}
```

## License

[Apache 2.0](./LICENSE)

### Third-party notices

TreeQuest bundles a few assets whose original authors retain copyright:

- **D3.js v7** by Mike Bostock and contributors, distributed under the ISC License. The full license text is included in `src/treequest/vis/assets/d3.LICENSE.txt` alongside the `d3.v7.min.js` binary.
- **Colormap samples** extracted from matplotlib and seaborn. The data and license references are embedded in `src/treequest/vis/assets/colormaps.json`.

These notices must be preserved in any redistribution of TreeQuest, including any compiled artifacts.
