────────────────────────────────────────────────────────────────────────────────
# /Users/adamkarvonen/repos/token-difr/tests/test_fireworks.py  (1/3)
────────────────────────────────────────────────────────────────────────────────
"""End-to-end tests for token-difr using Fireworks API backend."""

import asyncio
import os
from dataclasses import dataclass

import openai
import pytest
from openai import OpenAI
from tqdm import tqdm
from transformers import AutoTokenizer

from token_difr import TokenSequence, compute_metrics_summary, encode_thinking_response, verify_outputs_api
from token_difr.openrouter_api import openrouter_request

# Load .env file if present
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass

# Fireworks model name and corresponding HuggingFace model for tokenizer
FIREWORKS_MODEL = "accounts/fireworks/models/llama-v3p3-70b-instruct"
HF_MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"

# OpenRouter configuration
OPENROUTER_LLAMA_MODEL = "meta-llama/llama-3.3-70b-instruct"
OPENROUTER_KIMI_MODEL = "moonshotai/kimi-k2-thinking"
HF_KIMI_MODEL = "moonshotai/Kimi-K2-Thinking"
FIREWORKS_KIMI_MODEL = "fireworks/kimi-k2-thinking"

TEST_PROMPTS = [
    "What is the capital of France?",
    "Explain photosynthesis in simple terms.",
    "Write a haiku about the ocean.",
    "What is 2 + 2?",
    "List three primary colors.",
    "Describe the water cycle.",
    "What causes rainbows?",
    "Explain gravity to a child.",
]


def get_fireworks_api_key():
    """Get Fireworks API key from environment."""
    api_key = os.environ.get("FIREWORKS_API_KEY")
    if not api_key:
        pytest.skip("FIREWORKS_API_KEY environment variable not set")
    return api_key


def generate_outputs_fireworks(
    client: OpenAI,
    tokenizer,
    prompts: list[str],
    temperature: float,
    max_tokens: int = 100,
    model: str = FIREWORKS_MODEL,
    verbose: bool = True,
) -> tuple[list[TokenSequence], int]:
    """Generate outputs using Fireworks API for testing.

    Returns:
        Tuple of (outputs, vocab_size) where outputs is a list of TokenSequence
        and vocab_size is derived from the tokenizer.
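
    Example (minimal sketch; assumes FIREWORKS_API_KEY is set and the HF tokenizer is available):

        client = OpenAI(
            api_key=os.environ["FIREWORKS_API_KEY"],
            base_url="https://api.fireworks.ai/inference/v1",
        )
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
        outputs, vocab_size = generate_outputs_fireworks(
            client, tokenizer, ["What is 2 + 2?"], temperature=0.0, max_tokens=32
        )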
    """
    vocab_size = len(tokenizer)

    outputs = []
    iterator = tqdm(prompts, desc="Generating via Fireworks") if verbose else prompts
    for prompt in iterator:
        messages = [{"role": "user", "content": prompt}]
        rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        prompt_token_ids = tokenizer.encode(rendered, add_special_tokens=False)

        # Generate using Fireworks completions API
        response = client.completions.create(
            model=model,
            prompt=rendered,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        generated_text = response.choices[0].text
        generated_token_ids = tokenizer.encode(generated_text, add_special_tokens=False)

        outputs.append(
            TokenSequence(
                prompt_token_ids=prompt_token_ids,
                output_token_ids=generated_token_ids,
            )
        )

    return outputs, vocab_size


@pytest.mark.parametrize("temperature", [0.0])
def test_verify_outputs_fireworks(temperature):
    """Test Fireworks verification achieves >= 98% exact match and all metrics/summary fields are valid."""
    api_key = get_fireworks_api_key()

    top_k = 5  # Fireworks default
    top_p = 0.95
    seed = 42
    max_tokens = 100
    min_match_rate = 0.98

    # Create Fireworks client
    client = OpenAI(
        api_key=api_key,
        base_url="https://api.fireworks.ai/inference/v1",
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME, trust_remote_code=True)
    vocab_size = len(tokenizer)

    # Generate outputs
    outputs, vocab_size = generate_outputs_fireworks(
        client=client,
        tokenizer=tokenizer,
        prompts=TEST_PROMPTS,
        temperature=1e-8,  # Near-greedy for generation
        max_tokens=max_tokens,
    )

    total_tokens = sum(len(o.output_token_ids) for o in outputs)
    assert total_tokens > 0, "Should generate at least some tokens"

    # Verify outputs using Fireworks backend
    results = verify_outputs_api(
        outputs,
        backend="fireworks",
        vocab_size=vocab_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        seed=seed,
        client=client,
        model=FIREWORKS_MODEL,
        tokenizer=tokenizer,
        topk_logprobs=5,
    )

    # Check results structure
    assert len(results) == len(outputs), "Should have results for each output sequence"
    for seq_idx, seq_results in enumerate(results):
        assert len(seq_results) == len(outputs[seq_idx].output_token_ids), (
            f"Sequence {seq_idx}: should have metrics for each token"
        )

    # Check TokenMetrics fields
    for seq_results in results:
        for metrics in seq_results:
            assert isinstance(metrics.exact_match, bool)
            assert isinstance(metrics.prob, float)
            assert isinstance(metrics.margin, float)
            assert isinstance(metrics.logit_rank, (int, float))
            assert isinstance(metrics.gumbel_rank, (int, float))
            assert 0.0 <= metrics.prob <= 1.0, f"prob should be in [0, 1], got {metrics.prob}"
            assert metrics.logit_rank >= 0, f"logit_rank should be >= 0, got {metrics.logit_rank}"
            assert metrics.gumbel_rank >= 0, f"gumbel_rank should be >= 0, got {metrics.gumbel_rank}"

    # Check compute_metrics_summary
    summary = compute_metrics_summary(results)
    expected_keys = [
        "total_tokens",
        "exact_match_rate",
        "avg_prob",
        "avg_margin",
        "infinite_margin_rate",
        "avg_logit_rank",
        "avg_gumbel_rank",
    ]
    for key in expected_keys:
        assert key in summary, f"Missing key: {key}"
    assert summary["total_tokens"] == total_tokens
    assert 0.0 <= summary["exact_match_rate"] <= 1.0
    assert 0.0 <= summary["avg_prob"] <= 1.0
    assert 0.0 <= summary["infinite_margin_rate"] <= 1.0
    assert summary["avg_logit_rank"] >= 0.0
    assert summary["avg_gumbel_rank"] >= 0.0

    # Check match rate
    print(f"\nTemperature {temperature}: exact match rate = {summary['exact_match_rate']:.2%}")
    assert summary["exact_match_rate"] >= min_match_rate, (
        f"Temperature {temperature}: exact match rate {summary['exact_match_rate']:.2%} "
        f"is below {min_match_rate:.0%} threshold"
    )


# =============================================================================
# OpenRouter Generation Tests
# =============================================================================


def get_openrouter_api_key() -> str:
    """Get OpenRouter API key from environment."""
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        pytest.skip("OPENROUTER_API_KEY environment variable not set")
    return api_key


async def generate_outputs_openrouter(
    client: openai.AsyncOpenAI,
    tokenizer,
    prompts: list[str],
    model: str,
    provider: str,
    temperature: float = 0.0,
    max_tokens: int = 100,
    concurrency: int = 8,
) -> tuple[list[TokenSequence], int]:
    """Generate outputs using OpenRouter API.

    Uses openrouter_request from token_difr.openrouter_api.

    Returns:
        Tuple of (outputs, vocab_size).
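
    Example (minimal sketch; assumes OPENROUTER_API_KEY is set and tokenizer is the matching
    HF tokenizer; must be awaited inside an event loop):

        client = openai.AsyncOpenAI(
            api_key=os.environ["OPENROUTER_API_KEY"],
            base_url="https://openrouter.ai/api/v1",
        )
        outputs, vocab_size = await generate_outputs_openrouter(
            client, tokenizer, TEST_PROMPTS, model=OPENROUTER_LLAMA_MODEL, provider="Fireworks"
        )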
    """
    vocab_size = len(tokenizer)
    semaphore = asyncio.Semaphore(concurrency)
    conversations = [[{"role": "user", "content": p}] for p in prompts]

    async def _wrapped(idx: int, messages: list[dict[str, str]]) -> tuple[int, str, str]:
        async with semaphore:
            content, reasoning = await openrouter_request(
                client, model, messages, max_tokens, temperature, provider
            )
            return idx, content, reasoning

    tasks = [asyncio.create_task(_wrapped(i, conv)) for i, conv in enumerate(conversations)]
    results: list[tuple[str, str]] = [("", "")] * len(conversations)

    for fut in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc=f"OpenRouter ({provider})"):
        idx, content, reasoning = await fut
        results[idx] = (content, reasoning)

    # Convert to TokenSequence
    outputs = []
    for conv, (content, reasoning) in zip(conversations, results):
        rendered = tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
        prompt_token_ids = tokenizer.encode(rendered, add_special_tokens=False)
        response_token_ids = encode_thinking_response(content, reasoning, tokenizer, max_tokens)
        outputs.append(TokenSequence(prompt_token_ids=prompt_token_ids, output_token_ids=response_token_ids))

    return outputs, vocab_size


@pytest.mark.asyncio
async def test_verify_openrouter_generation_with_fireworks():
    """Test: Generate via OpenRouter (Fireworks provider), verify via Fireworks API."""
    fireworks_api_key = get_fireworks_api_key()
    openrouter_api_key = get_openrouter_api_key()

    top_k = 5
    top_p = 0.95
    seed = 42
    max_tokens = 100
    min_match_rate = 0.98

    # Create clients
    fireworks_client = OpenAI(
        api_key=fireworks_api_key,
        base_url="https://api.fireworks.ai/inference/v1",
    )
    openrouter_client = openai.AsyncOpenAI(
        api_key=openrouter_api_key,
        base_url="https://openrouter.ai/api/v1",
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME, trust_remote_code=True)
    vocab_size = len(tokenizer)

    # Generate via OpenRouter with Fireworks as provider
    outputs, vocab_size = await generate_outputs_openrouter(
        client=openrouter_client,
        tokenizer=tokenizer,
        prompts=TEST_PROMPTS,
        model=OPENROUTER_LLAMA_MODEL,
        provider="Fireworks",
        temperature=0.0,
        max_tokens=max_tokens,
    )

    total_tokens = sum(len(o.output_token_ids) for o in outputs)
    assert total_tokens > 0, "Should generate at least some tokens"

    # Verify via Fireworks
    results = verify_outputs_api(
        outputs,
        backend="fireworks",
        vocab_size=vocab_size,
        temperature=0.0,
        top_k=top_k,
        top_p=top_p,
        seed=seed,
        client=fireworks_client,
        model=FIREWORKS_MODEL,
        tokenizer=tokenizer,
        topk_logprobs=5,
    )

    summary = compute_metrics_summary(results)
    print(f"\nOpenRouter->Fireworks: exact match rate = {summary['exact_match_rate']:.2%}")

    assert summary["exact_match_rate"] >= min_match_rate, (
        f"Exact match rate {summary['exact_match_rate']:.2%} is below {min_match_rate:.0%} threshold"
    )


# =============================================================================
# Echo Tests - Verify Fireworks doesn't modify formatted prompts
# =============================================================================


@dataclass
class EchoModelConfig:
    """Configuration for echo testing a model."""

    fireworks_model: str
    hf_model: str
    skip_first_token: bool  # Whether Fireworks prepends an empty token
    is_thinking_model: bool  # Whether the model uses <think>...</think> tags


# Model configurations for echo tests
ECHO_MODEL_CONFIGS = {
    "llama": EchoModelConfig(
        fireworks_model=FIREWORKS_MODEL,
        hf_model=HF_MODEL_NAME,
        skip_first_token=True,  # Llama's BOS token is rendered as an empty token
        is_thinking_model=False,
    ),
    "kimi": EchoModelConfig(
        fireworks_model=FIREWORKS_KIMI_MODEL,
        hf_model=HF_KIMI_MODEL,
        skip_first_token=False,  # Kimi doesn't have this issue
        is_thinking_model=True,
    ),
}


def get_echo_test_messages(is_thinking_model: bool) -> list[dict[str, str]]:
    """Get test messages appropriate for the model type."""
    if is_thinking_model:
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "<think>Thinking...</think>The answer is 4."},
        ]
    else:
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "How are you?"},
        ]


@pytest.mark.parametrize("model_key", ["llama", "kimi"])
def test_echo(model_key: str):
    """Test that Fireworks echoes prompts without modification.

    Verifies:
    1. The number of echoed tokens matches the expected count (no hidden prefix/suffix)
    2. The non-special tokens match what we expect
    """
    config = ECHO_MODEL_CONFIGS[model_key]
    verbose = True

    api_key = get_fireworks_api_key()

    client = OpenAI(
        api_key=api_key,
        base_url="https://api.fireworks.ai/inference/v1",
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
    except Exception as e:
        pytest.skip(f"Could not load {config.hf_model} tokenizer: {e}")

    # Get appropriate test messages for this model type
    messages = get_echo_test_messages(config.is_thinking_model)
    formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    expected_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)

    # Send with echo=True to get the prompt tokens back
    try:
        response = client.completions.create(
            model=config.fireworks_model,
            prompt=formatted_prompt,
            max_tokens=1,
            temperature=0.0,
            echo=True,
            logprobs=5,
        )
    except Exception as e:
        pytest.skip(f"{model_key} model not available on Fireworks: {e}")

    # The echoed tokens should include our prompt
    echoed_tokens = response.choices[0].logprobs.tokens

    # Some models (like Llama) have Fireworks prepend an empty token for BOS
    if config.skip_first_token:
        echoed_tokens = echoed_tokens[1:]

    # Exclude the generated token at the end
    prompt_tokens = echoed_tokens[: len(expected_tokens)]

    print(f"\n[{model_key}] Expected {len(expected_tokens)} tokens, got {len(prompt_tokens)} echoed tokens")

    # Key check: token count should match (no hidden prefix/suffix added)
    assert len(prompt_tokens) == len(expected_tokens), (
        f"[{model_key}] Token count mismatch: expected {len(expected_tokens)}, got {len(prompt_tokens)}. "
        "Fireworks may be adding hidden tokens to the prompt."
    )

    # Convert expected token IDs to strings for comparison
    expected_token_strings = [tokenizer.decode([tid]) for tid in expected_tokens]

    # Check for mismatches (special tokens may render differently, which is OK)
    mismatches = []
    special_token_mismatches = 0
    for i, (expected, echoed) in enumerate(zip(expected_token_strings, prompt_tokens)):
        if verbose:
            print(f"Expected: {expected}, Echoed: {echoed}")
        if expected != echoed:
            # Check if this is a special token rendering difference
            is_special = (echoed == "" and expected.startswith("<") and expected.endswith(">")) or (
                expected.startswith("<|") and expected.endswith("|>")
            )
            if is_special:
                special_token_mismatches += 1
            else:
                mismatches.append(f"  Position {i}: expected {repr(expected)}, got {repr(echoed)}")

    if mismatches:
        print(f"\nNon-special token mismatches ({len(mismatches)}):")
        for m in mismatches[:10]:
            print(m)

    if special_token_mismatches:
        print(f"Special token rendering differences: {special_token_mismatches} (expected)")

    # Only fail on non-special token mismatches
    assert len(mismatches) == 0, f"Found {len(mismatches)} unexpected token mismatches"
    print(f"\n[{model_key}] Echo test: PASSED (token counts match, no unexpected modifications)")

────────────────────────────────────────────────────────────────────────────────
# /Users/adamkarvonen/repos/token-difr/src/token_difr/api.py  (2/3)
────────────────────────────────────────────────────────────────────────────────
"""API-based verification using Tinker and Fireworks backends."""

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Iterator, Literal

import torch
from openai import OpenAI
from openai.types.completion_choice import Logprobs
from tqdm import tqdm

from token_difr.common import (
    TokenMetrics,
    TokenSequence,
    _as_list,
    compute_metrics_summary,
)

# Type alias for a single position's top-k logprobs: list of (token_id, logprob) tuples
PositionLogprobs = list[tuple[int, float]] | None


@dataclass
class SparseLogprobs:
    """Compact representation of top-k logprobs for a sequence.

    This stores only the top-k (token_id, logprob) pairs per position,
    avoiding the memory cost of full vocab-size tensors.

    Attributes:
        index: Index of this sequence in the original outputs list.
        gen_ids: The generated token IDs being verified.
        logprobs: Per-position sparse logprobs. Each entry is either None
            or a list of (token_id, logprob) tuples for the top-k tokens.
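
    Example (illustrative token ids and logprobs only):

        SparseLogprobs(
            index=0,
            gen_ids=[450, 7483],
            logprobs=[
                [(450, -0.01), (1576, -4.9)],   # top-k candidates at position 0
                [(7483, -0.10), (2982, -2.3)],  # top-k candidates at position 1
            ],
        )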
    """

    index: int
    gen_ids: list[int]
    logprobs: list[PositionLogprobs]


def _sparse_logprobs_to_tensor(
    sparse_logprobs: list[PositionLogprobs],
    n_tokens: int,
    device: torch.device,
    vocab_size: int,
) -> torch.Tensor:
    """Convert sparse logprobs into a dense full-vocabulary tensor.

    Args:
        sparse_logprobs: Per-position sparse logprobs. Each entry is either None
            or a list of (token_id, logprob) tuples for the top-k tokens.
        n_tokens: Expected number of tokens.
        device: Torch device.
        vocab_size: Vocabulary size for the dense tensor.

    Returns:
        Tensor of shape (n_tokens, vocab_size) with logprobs filled in.
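
    Example (minimal sketch; vocab_size=4 for illustration):

        dense = _sparse_logprobs_to_tensor(
            [[(2, -0.1)], None], n_tokens=2, device=torch.device("cpu"), vocab_size=4
        )
        # dense[0] is [-inf, -inf, -0.1, -inf]; dense[1] is all -inf (row was None)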
    """
    if len(sparse_logprobs) != n_tokens:
        raise ValueError(f"Expected {n_tokens} logprob rows, got {len(sparse_logprobs)}")

    dense_logprobs = torch.full((n_tokens, vocab_size), float("-inf"), device=device)

    for j, row in enumerate(sparse_logprobs):
        if row is None:
            continue

        token_ids = torch.tensor([tok_id for tok_id, _ in row], device=device, dtype=torch.long)
        logprobs = torch.tensor([logprob for _, logprob in row], device=device)
        dense_logprobs[j].scatter_(0, token_ids, logprobs)

    return dense_logprobs


def _fetch_fireworks_logprobs(
    prompt_token_ids: list[int],
    gen_ids: list[int],
    client: OpenAI,
    model: str,
    tokenizer,
    topk_logprobs: int,
) -> Logprobs | None:
    """Fetch logprobs from Fireworks API (blocking call for use with ThreadPoolExecutor)."""
    prompt_text = tokenizer.decode(prompt_token_ids, skip_special_tokens=False)
    output_text = tokenizer.decode(gen_ids, skip_special_tokens=False)

    response = client.completions.create(
        model=model,
        prompt=prompt_text + output_text,
        max_tokens=1,
        logprobs=topk_logprobs,
        echo=True,
    )

    return response.choices[0].logprobs


def _iter_tinker_logprobs(
    request_data: list[tuple[int, list[int], list[int]]],
    sampling_client,
    topk_logprobs: int,
    verbose: bool,
) -> Iterator[SparseLogprobs]:
    """Submit Tinker requests and yield SparseLogprobs as results complete."""
    import tinker

    # Submit all requests (they return futures)
    futures = []
    for i, prompt_token_ids, gen_ids in request_data:
        full_sequence = prompt_token_ids + gen_ids
        full_prompt = tinker.ModelInput.from_ints(full_sequence)

        future = sampling_client.sample(
            prompt=full_prompt,
            sampling_params=tinker.SamplingParams(max_tokens=1),
            num_samples=1,
            include_prompt_logprobs=True,
            topk_prompt_logprobs=topk_logprobs,
        )
        futures.append((i, prompt_token_ids, gen_ids, future))

    # Yield results with progress bar
    iterator = tqdm(futures, desc="Verifying via tinker API") if verbose else futures
    for i, prompt_token_ids, gen_ids, future in iterator:
        logprob_result = future.result()
        prompt_len = len(prompt_token_ids)
        gen_len = len(gen_ids)

        # Extract just the slice we need (still sparse, no tensor conversion)
        sparse_logprobs = logprob_result.topk_prompt_logprobs[prompt_len : prompt_len + gen_len]
        yield SparseLogprobs(index=i, gen_ids=gen_ids, logprobs=sparse_logprobs)


def _fireworks_to_sparse_logprobs(
    top_logprobs: list[dict[str, float] | None],
    start_idx: int,
    n_tokens: int,
    tokenizer,
) -> list[PositionLogprobs]:
    """Convert Fireworks logprobs to sparse format matching Tinker's output.

    Returns per-position sparse logprobs where each entry is either None
    or a list of (token_id, logprob) tuples.
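
    Example (illustrative; assumes tok is a HF tokenizer and each string encodes to one token):

        top_logprobs = [None, {"Paris": -0.02, " Paris": -4.1}]
        rows = _fireworks_to_sparse_logprobs(top_logprobs, start_idx=1, n_tokens=1, tokenizer=tok)
        # rows == [[(id_of("Paris"), -0.02), (id_of(" Paris"), -4.1)]] with tokenizer-specific ids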
    """
    slice_rows = top_logprobs[start_idx : start_idx + n_tokens]
    if len(slice_rows) != n_tokens:
        raise ValueError(f"Expected {n_tokens} logprob rows, got {len(slice_rows)}")

    result: list[PositionLogprobs] = []
    for row in slice_rows:
        if row is None:
            result.append(None)
            continue

        token_logprobs: list[tuple[int, float]] = []
        for token_str, logprob in row.items():
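            # Keep only token strings that re-encode to a single token id; multi-token strings are dropped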
            token_ids = tokenizer.encode(token_str, add_special_tokens=False)
            if len(token_ids) == 1:
                token_logprobs.append((token_ids[0], logprob))
        result.append(token_logprobs if token_logprobs else None)

    return result


def _iter_fireworks_logprobs(
    request_data: list[tuple[int, list[int], list[int]]],
    client: OpenAI,
    model: str,
    tokenizer,
    topk_logprobs: int,
    verbose: bool,
) -> Iterator[SparseLogprobs]:
    """Submit Fireworks requests and yield SparseLogprobs as results complete."""
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = []
        for i, prompt_token_ids, gen_ids in request_data:
            future = executor.submit(
                _fetch_fireworks_logprobs,
                prompt_token_ids=prompt_token_ids,
                gen_ids=gen_ids,
                client=client,
                model=model,
                tokenizer=tokenizer,
                topk_logprobs=topk_logprobs,
            )
            futures.append((i, prompt_token_ids, gen_ids, future))

        # Yield results with progress bar
        iterator = tqdm(futures, desc="Verifying via fireworks API") if verbose else futures
        for i, prompt_token_ids, gen_ids, future in iterator:
            logprobs_data = future.result()
            prompt_len = len(prompt_token_ids)
            gen_len = len(gen_ids)

            # +1 for the empty token Fireworks prepends
            start_idx = prompt_len + 1

            sparse_logprobs = _fireworks_to_sparse_logprobs(
                logprobs_data.top_logprobs,
                start_idx=start_idx,
                n_tokens=gen_len,
                tokenizer=tokenizer,
            )
            yield SparseLogprobs(index=i, gen_ids=gen_ids, logprobs=sparse_logprobs)


def _compute_verification_metrics_from_logprobs(
    logprobs_JV: torch.Tensor,
    gen_ids: list[int],
    temperature: float,
    top_k: int,
    top_p: float,
    seed: int,
    device: torch.device,
) -> list[TokenMetrics]:
    """Compute verification metrics from log probabilities (not raw logits).

    Currently only supports greedy verification (temperature=0). The signature
    includes all sampling parameters for future expansion.

    For greedy verification:
    - exact_match: True if claimed token is the argmax of logprobs
    - prob: probability of the claimed token
    - margin: logprob(top1) - logprob(claimed_token), 0 if exact match
    - logit_rank: rank of claimed token by logprob (0 = highest)
    - gumbel_rank: same as logit_rank for greedy (placeholder for future)
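
    Worked example (illustrative): for a position with logprobs [-0.1, -2.3, -inf] over a
    3-token vocab and claimed token id 1, the argmax is id 0, so exact_match=False,
    prob=exp(-2.3) ~= 0.10, margin=(-0.1)-(-2.3)=2.2, and logit_rank=1 (one token has a
    higher logprob than the claimed one).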
    """
    # Keep parameters for future use
    _ = temperature, top_k, top_p, seed

    J = logprobs_JV.shape[0]
    gold_col_idx_J = torch.as_tensor(gen_ids, device=device, dtype=torch.long)

    # Convert log probs to probs
    probs_JV = torch.exp(logprobs_JV.float())

    row_idx_J = torch.arange(J, device=device)
    gold_logprobs_J = logprobs_JV[row_idx_J, gold_col_idx_J]

    # Compute rank based on log probs (higher logprob = better, so rank 0 = best)
    logit_ranks_J = (logprobs_JV > gold_logprobs_J.unsqueeze(1)).sum(dim=1).float()

    probs_gold_J = probs_JV.gather(1, gold_col_idx_J.view(-1, 1)).squeeze(1)

    # Greedy verification: predicted token is argmax of logprobs
    pred_ids_J = logprobs_JV.argmax(dim=-1)

    # Margin: logprob(top1) - logprob(claimed)
    max_logprobs_J = logprobs_JV.max(dim=-1).values
    margins_J = max_logprobs_J - gold_logprobs_J

    seq_token_metrics: list[TokenMetrics] = []
    for j in range(J):
        actual_id = int(gen_ids[j])
        token_metrics = TokenMetrics(
            exact_match=bool(int(pred_ids_J[j]) == actual_id),
            prob=float(probs_gold_J[j].item()),
            margin=float(margins_J[j].item()),
            logit_rank=float(logit_ranks_J[j].item()),
            gumbel_rank=float(logit_ranks_J[j].item()),  # Same as logit_rank for greedy
        )
        seq_token_metrics.append(token_metrics)

    return seq_token_metrics


@torch.inference_mode()
def verify_outputs_api(
    outputs: list[TokenSequence],
    backend: Literal["tinker", "fireworks"],
    vocab_size: int,
    *,
    temperature: float,
    top_k: int,
    top_p: float,
    seed: int,
    # Backend-specific args
    sampling_client=None,  # For Tinker
    client: OpenAI | None = None,  # For Fireworks
    model: str | None = None,  # For Fireworks
    tokenizer=None,  # For Fireworks
    topk_logprobs: int = 20,
    verbose: bool = True,
) -> list[list[TokenMetrics]]:
    """
    Verify LLM outputs using API-based verification.

    This function takes token sequences (prompt + generated output) and verifies
    whether the outputs could have been produced by the specified model using
    the given sampling parameters. Requests are submitted in parallel for efficiency.

    Args:
        outputs: List of TokenSequence objects containing prompt and output token IDs.
        backend: Which backend to use ("tinker" or "fireworks").
        vocab_size: The vocabulary size of the model.
        temperature: Sampling temperature used during generation. Required.
        top_k: Top-k sampling parameter. Required.
        top_p: Top-p (nucleus) sampling parameter. Required.
        seed: Random seed used during generation. Required.
        sampling_client: A Tinker SamplingClient (required for backend="tinker").
        client: An OpenAI client configured for Fireworks (required for backend="fireworks").
        model: The model name for Fireworks (required for backend="fireworks").
        tokenizer: HuggingFace tokenizer (required for backend="fireworks").
        topk_logprobs: Number of top logprobs to request. Default: 20.
        verbose: Whether to show progress and print a summary. Default: True.

    Returns:
        List of lists of TokenMetrics, one per token in each output sequence.
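
    Example (minimal sketch for the Fireworks backend; assumes a Fireworks-compatible
    OpenAI client and a matching HF tokenizer):

        results = verify_outputs_api(
            outputs,
            backend="fireworks",
            vocab_size=len(tokenizer),
            temperature=0.0,
            top_k=5,
            top_p=0.95,
            seed=42,
            client=client,
            model="accounts/fireworks/models/llama-v3p3-70b-instruct",
            tokenizer=tokenizer,
            topk_logprobs=5,
        )
        summary = compute_metrics_summary(results)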
    """
    device = torch.device("cpu")

    all_token_metrics: list[list[TokenMetrics]] = [[] for _ in outputs]

    # Prepare request data
    request_data: list[tuple[int, list[int], list[int]]] = []  # (index, prompt_token_ids, gen_ids)
    for i, req in enumerate(outputs):
        prompt_token_ids: list[int] = _as_list(req.prompt_token_ids)
        gen_ids: list[int] = _as_list(req.output_token_ids)
        if len(gen_ids) == 0:
            continue
        request_data.append((i, prompt_token_ids, gen_ids))

    if len(request_data) == 0:
        return all_token_metrics

    # Get logprobs iterator based on backend
    # Both iterators yield SparseLogprobs - compact representation of top-k logprobs
    if backend == "tinker":
        if sampling_client is None:
            raise ValueError("sampling_client is required for tinker backend")
        logprobs_iter = _iter_tinker_logprobs(
            request_data=request_data,
            sampling_client=sampling_client,
            topk_logprobs=topk_logprobs,
            verbose=verbose,
        )
    elif backend == "fireworks":
        if client is None or model is None or tokenizer is None:
            raise ValueError("client, model, and tokenizer are required for fireworks backend")
        logprobs_iter = _iter_fireworks_logprobs(
            request_data=request_data,
            client=client,
            model=model,
            tokenizer=tokenizer,
            topk_logprobs=topk_logprobs,
            verbose=verbose,
        )
    else:
        raise ValueError(f"Unknown backend: {backend}")

    # Process all results with unified loop
    # Materialize tensor only when needed to minimize memory usage
    for result in logprobs_iter:
        logprobs_JV = _sparse_logprobs_to_tensor(
            sparse_logprobs=result.logprobs,
            n_tokens=len(result.gen_ids),
            device=device,
            vocab_size=vocab_size,
        )
        seq_token_metrics = _compute_verification_metrics_from_logprobs(
            logprobs_JV=logprobs_JV,
            gen_ids=result.gen_ids,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            seed=seed,
            device=device,
        )
        all_token_metrics[result.index] = seq_token_metrics

    if verbose:
        summary = compute_metrics_summary(all_token_metrics)
        print("Verification Summary:")
        print(f"  Total tokens: {summary['total_tokens']}")
        print(f"  Exact match rate: {summary['exact_match_rate']:.2%}")
        print(f"  Average probability: {summary['avg_prob']:.4f}")
        print(f"  Average margin: {summary['avg_margin']:.4f} ({summary['infinite_margin_rate']:.2%} infinite)")

    return all_token_metrics

────────────────────────────────────────────────────────────────────────────────
# /Users/adamkarvonen/repos/token-difr/src/token_difr/openrouter_api.py  (3/3)
────────────────────────────────────────────────────────────────────────────────
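"""Generate responses via OpenRouter with per-provider routing and save them in a VLLM-style tokenized JSON format."""
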
import asyncio
import json
from pathlib import Path

import openai
from transformers import AutoTokenizer
from tqdm import tqdm

from token_difr.common import construct_prompts, encode_thinking_response


def _sanitize(name: str) -> str:
    return name.replace("/", "_").replace(".", "_")


async def openrouter_request(
    client: openai.AsyncOpenAI,
    model: str,
    messages: list[dict[str, str]],
    max_tokens: int,
    temperature: float,
    provider: str,
) -> tuple[str, str]:
    """Make a single OpenRouter request.

    Returns:
        Tuple of (content, reasoning) where reasoning may be empty for non-thinking models.
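
    Example (minimal sketch; assumes an OpenRouter API key and a provider that serves the model):

        client = openai.AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
        content, reasoning = await openrouter_request(
            client,
            "meta-llama/llama-3.3-70b-instruct",
            [{"role": "user", "content": "What is 2 + 2?"}],
            max_tokens=64,
            temperature=0.0,
            provider="Fireworks",
        )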
    """
    completion = await client.chat.completions.create(
        model=model,
        messages=messages,  # type: ignore[arg-type]
        max_tokens=max_tokens,
        temperature=temperature,
        extra_body={
            "provider": {"order": [provider]},
        },
    )
    content = completion.choices[0].message.content or ""
    reasoning = getattr(completion.choices[0].message, "reasoning", None) or ""
    return content, reasoning


async def run_all_prompts(
    client: openai.AsyncOpenAI,
    model: str,
    prompts: list[list[dict[str, str]]],
    max_tokens: int,
    temperature: float,
    provider: str,
    concurrency: int = 8,
) -> list[tuple[list[dict[str, str]], str, str]]:
    """Generate responses for all prompts concurrently.

    Returns:
        List of (prompt, content, reasoning) tuples.
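
    Example (minimal sketch; must be awaited inside an event loop):

        samples = await run_all_prompts(
            client, model, prompts, max_tokens=500, temperature=0.0,
            provider="deepinfra", concurrency=50,
        )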
    """
    semaphore = asyncio.Semaphore(concurrency)

    async def _wrapped(idx: int, prompt: list[dict[str, str]]) -> tuple[int, str, str]:
        async with semaphore:
            content, reasoning = await openrouter_request(
                client,
                model,
                prompt,
                max_tokens,
                temperature=temperature,
                provider=provider,
            )
            return idx, content, reasoning

    tasks = [asyncio.create_task(_wrapped(i, p)) for i, p in enumerate(prompts)]
    results: list[tuple[list[dict[str, str]], str, str]] = [([], "", "") for _ in prompts]

    for fut in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc=f"OpenRouter ({provider})"):
        idx, content, reasoning = await fut
        results[idx] = (prompts[idx], content, reasoning)

    return results


def save_results(
    samples: list[tuple[list[dict[str, str]], str, str]],
    save_path: Path,
    config: dict[str, object],
    model_name: str,
    max_tokens: int,
) -> None:
    """Save samples as JSON in VLLM-style format with tokenized prompts and responses.

    Args:
        samples: List of (prompt, content, reasoning) tuples.
        save_path: Path to save the JSON file.
        config: Configuration dictionary to include in output.
        model_name: HuggingFace model name for tokenizer.
        max_tokens: Maximum tokens for response truncation.
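
    The saved JSON has the shape (token ids elided):

        {
            "config": {...},
            "samples": [
                {"prompt_token_ids": [...], "outputs": [{"token_ids": [...]}]}
            ]
        }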
    """
    # Load tokenizer for the model
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Tokenize prompts and responses
    vllm_samples = []
    for prompt, content, reasoning in tqdm(samples, desc="Tokenizing samples"):
        # Tokenize prompt (conversation array)
        rendered_prompt = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
        prompt_token_ids = tokenizer.encode(rendered_prompt, add_special_tokens=False, return_tensors=None)

        # Tokenize response using encode_thinking_response for proper handling of thinking models
        response_token_ids = encode_thinking_response(content, reasoning, tokenizer, max_tokens)

        # Create VLLM-style sample
        vllm_samples.append({"prompt_token_ids": prompt_token_ids, "outputs": [{"token_ids": response_token_ids}]})

    del tokenizer

    # Save as JSON
    payload = {"config": config, "samples": vllm_samples}
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(payload, f)
    print(f"Saved {len(samples)} samples to {save_path}")


async def main():
    model_name = "meta-llama/llama-3.1-8b-instruct"
    max_tokens = 500
    temperature = 0.0
    concurrency = 50
    n_samples = 2000
    max_ctx_len = 512
    save_dir = Path("openrouter_responses")

    # Load OpenRouter API key from file to keep the script simple.
    with open("openrouter_api_key.txt") as f:
        api_key = f.read().strip()

    client = openai.AsyncOpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
    )

    for provider in ["cerebras", "hyperbolic", "groq", "siliconflow/fp8", "deepinfra"]:
        prompts = construct_prompts(
            n_prompts=n_samples, max_ctx_len=max_ctx_len, model_name=model_name
        )
        print(f"Loaded {len(prompts)} prompts from dataset.")

        responses = await run_all_prompts(
            client,
            model_name,
            prompts,
            max_tokens=max_tokens,
            temperature=temperature,
            provider=provider,
            concurrency=concurrency,
        )

        save_dir.mkdir(parents=True, exist_ok=True)
        model_tag = _sanitize(f"{provider}_{model_name}")
        save_filename = f"openrouter_{model_tag}_token_difr_prompts_test.json"
        config = {
            "model": model_name,
            "provider": provider,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "n_samples": n_samples,
            "max_ctx_len": max_ctx_len,
        }
        save_results(
            responses,
            save_dir / save_filename,
            config=config,
            model_name=model_name,
            max_tokens=max_tokens,
        )


if __name__ == "__main__":
    asyncio.run(main())

