===== ludic_envs/tictactoe.py =====
from __future__ import annotations
import random
import re
import math
import asyncio
from typing import Dict, List, Tuple, Optional, Type, Literal, Any
from enum import Enum
from functools import lru_cache

from pydantic import BaseModel, ValidationError
from openai import OpenAI

from ludic_envs.ludic_env import LudicEnv, StepResult, Rollout

def extract_tag_value(text: str, tag: str) -> Optional[str]:
    """
    Extract the content of the first `<tag>...</tag>` pair found in `text`.

    Args:
        text: The string to search (may span multiple lines).
        tag: The literal tag name, without angle brackets.

    Returns:
        The text between the opening and closing tag with surrounding
        whitespace stripped, or None when no such pair exists.
    """
    # Escape the tag so names containing regex metacharacters (".", "+", "|", ...)
    # are matched literally instead of being interpreted as a pattern.
    name = re.escape(tag)
    # Non-greedy + DOTALL: take the shortest match, allowing newlines inside.
    match = re.search(f'<{name}>(.*?)</{name}>', text, re.DOTALL)
    return match.group(1).strip() if match else None

# --- Value Function Calculation ---

def _create_value_table() -> Dict[Tuple[Optional[str], ...], float]:
    """
    Creates a lookup table for the value of every possible Tic-Tac-Toe state.
    The value is from the perspective of the player whose turn it is.
    -  1: The current player can force a win.
    -  0: The game is a forced draw.
    - -1: The current player will lose to an optimal opponent.
    """
    memo = {}
    win_lines = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6),
        (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)
    )

    def get_winner(board: Tuple[Optional[str], ...]) -> Optional[str]:
        for a, b, c in win_lines:
            if board[a] and board[a] == board[b] == board[c]:
                return board[a]
        return None

    def get_value(board: Tuple[Optional[str], ...], player: str) -> float:
        state = (board, player)
        if state in memo:
            return memo[state]

        winner = get_winner(board)
        if winner:
            return -1.0  # Previous player won, so current player loses

        if all(cell is not None for cell in board):
            return 0.0  # Draw

        opponent = 'O' if player == 'X' else 'X'
        possible_moves = [i for i, cell in enumerate(board) if cell is None]
        
        # Assume loss until a better move is found
        best_value = -1.0
        for move in possible_moves:
            new_board_list = list(board)
            new_board_list[move] = player
            new_board = tuple(new_board_list)
            
            # Value for us is the negative of the value for the opponent
            value = -get_value(new_board, opponent)
            if value > best_value:
                best_value = value
            # If a winning move is found, prune other branches
            if best_value == 1.0:
                break
        
        memo[state] = best_value
        return best_value

    # Populate the table for all reachable states
    initial_board = tuple([None] * 9)
    get_value(initial_board, 'X') # Start with 'X'
    get_value(initial_board, 'O') # and 'O'
    
    # We only need the board state as the key, since the player to move can be inferred
    final_table = {}
    for (board, player), value in memo.items():
        if board not in final_table:
             # The stored value is for 'player', which is the one to move
             final_table[board] = value
    
    return final_table


# --- Semantic Action Space (Specific to this environment) ---
class Height(str, Enum):
    """Row component of a move string; the value is the row letter."""
    # Letters match the row labels printed next to the board in observations.
    TOP = "A"
    CENTER = "B"
    BOTTOM = "C"

class Width(str, Enum):
    """Column component of a move string; the value is the column digit."""
    # Digits match the column header printed above the board in observations.
    LEFT = "1"
    CENTER = "2"
    RIGHT = "3"

class Action(BaseModel):
    """A validated move: one row (`height`) plus one column (`width`)."""
    height: Height
    width: Width

    @classmethod
    def from_str(cls, move_str: str) -> "Action":
        """
        Build an Action from a two-character move such as 'A1' or 'c2'.

        Surrounding whitespace is ignored and the text is upper-cased first.

        Raises:
            ValueError: if the cleaned text is not exactly two characters, or
                if either character is not a valid row letter / column digit.
        """
        cleaned = move_str.strip().upper()
        if len(cleaned) != 2:
            raise ValueError("Move must be two characters, e.g., 'A1', 'B3'.")

        # First character selects the row, second the column.
        row_char, col_char = cleaned
        return cls(height=Height(row_char), width=Width(col_char))


# --- Main Environment Implementation ---
class TicTacToe(LudicEnv):
    """
    Tic-Tac-Toe against a uniformly random opponent, exposed via the LudicEnv interface.

    'X' always makes the first move. The agent's mark is drawn at random in
    `reset`; when the agent is 'O' the opponent opens the game before the
    first observation is produced.

    Constructor keyword arguments (both optional, default False):
        show_opponent_move: prefix observations with the opponent's latest move.
        reward_shaping: return value-difference rewards on non-terminal steps.
    """
    # Value of each non-terminal board for the player about to move.
    VALUE_FUNCTION: Dict[Tuple[Optional[str], ...], float] = _create_value_table()
    SUGGESTED_SYSPROMPT = (
        "You are a tic-tac-toe player. You are playing as '{mark}'. "
        "Always choose the move most likely to lead to an eventual win. "
        "Return your move as an XML object with a single property 'move', "
        "like so: <move>A1</move>. Optional moves are 'A1', 'B2', 'C3', etc."
    )
    # (row, column) -> flat board index, row-major starting at the top-left cell.
    COORD_MAP: Dict[Tuple[Height, Width], int] = {
        (row, col): 3 * row_no + col_no
        for row_no, row in enumerate(Height)
        for col_no, col in enumerate(Width)
    }
    # Flat board index -> move name such as "A1".
    INDEX_TO_NAME: Dict[int, str] = {
        index: f"{row.value}{col.value}" for (row, col), index in COORD_MAP.items()
    }
    # Index triples that win the game: three rows, three columns, two diagonals.
    WIN_LINES: Tuple[Tuple[int, int, int], ...] = (
        (0, 1, 2),
        (3, 4, 5),
        (6, 7, 8),
        (0, 3, 6),
        (1, 4, 7),
        (2, 5, 8),
        (0, 4, 8),
        (2, 4, 6),
    )

    def __init__(self, **kwargs) -> None:
        """Create an environment with an empty board; call `reset` before playing."""
        super().__init__()
        self.action_space = Action

        # Options supplied by the caller.
        self.show_opponent_move: bool = kwargs.get("show_opponent_move", False)
        self.reward_shaping: bool = kwargs.get("reward_shaping", False)

        # Per-game state, re-initialised by `reset`.
        self.board: List[Optional[str]] = [None] * 9
        self.agent_mark: str = 'X'
        self.opponent_mark: str = 'O'
        self.done: bool = False
        self.move_number: int = 0
        self.system_prompt: str = ""
        self.last_opponent_move: Optional[str] = None

    def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Start a new game.

        Args:
            seed: When given, seeds the module-level `random` generator first.

        Returns:
            A pair `(state, info)` where `state` holds the first observation
            and the formatted system prompt, and `info` is an empty dict.
        """
        if seed is not None:
            random.seed(seed)

        self.board = [None] * 9
        self.done = False
        self.move_number = 0
        self.last_opponent_move = None

        # Draw the agent's side, then build the prompt for that side.
        marks = random.choice([('X', 'O'), ('O', 'X')])
        self.agent_mark, self.opponent_mark = marks
        self.system_prompt = self.SUGGESTED_SYSPROMPT.format(mark=self.agent_mark)

        # 'X' opens the game, so an agent playing 'O' first receives a random reply.
        if self.agent_mark == 'O':
            opening_idx = random.choice(self._empty_cells())
            self._place(opening_idx, self.opponent_mark)
            self.last_opponent_move = self.INDEX_TO_NAME[opening_idx]

        state = {"observation": self._obs(), "system_prompt": self.system_prompt}
        return state, {}

    def step(self, action: str) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """
        Apply the agent's move and, if the game continues, a random opponent reply.

        Args:
            action: Move text such as "A1".

        Returns:
            `(observation_dict, reward, done, info)`. An unparsable move or a
            move onto an occupied cell ends the game with reward -5.0 and
            `info["invalid_move"]` set to True.

        Raises:
            RuntimeError: if called after the game has ended.
        """
        if self.done:
            raise RuntimeError("Game has ended. Call reset().")

        info = {"invalid_move": False}
        shaping = self.reward_shaping

        # Value of the position before the agent moves (agent to move).
        value_before = self.VALUE_FUNCTION.get(tuple(self.board), 0.0) if shaping else 0.0

        try:
            move = Action.from_str(action)
            cell = self.COORD_MAP[(move.height, move.width)]
            if self.board[cell] is not None:
                raise ValueError(f"Cell {self.INDEX_TO_NAME[cell]} is already occupied.")
        except (ValueError, ValidationError) as e:
            self.done = True
            info.update({"invalid_move": True, "error": str(e)})
            failure_obs = {"observation": f"Your action was invalid. Reason: {e}\nGame Over."}
            return failure_obs, -5.0, self.done, info

        self._place(cell, self.agent_mark)
        self.move_number += 1

        # Did the agent's own move finish the game?
        outcome = self._terminal_reward()
        if outcome is not None:
            self.done = True
            return {"observation": self._obs()}, outcome, self.done, info

        shaped_reward = 0.0
        if shaping:
            # The table is from the mover's view and the opponent moves next,
            # so negate to obtain the agent's view.
            value_after = -self.VALUE_FUNCTION.get(tuple(self.board), 0.0)
            shaped_reward = value_after - value_before
            info["reward_shaping_details"] = f"V(s')={value_after:.2f}, V(s)={value_before:.2f}"

        free_cells = self._empty_cells()
        if free_cells:
            reply_idx = random.choice(free_cells)
            self._place(reply_idx, self.opponent_mark)
            self.last_opponent_move = self.INDEX_TO_NAME[reply_idx]

        # Did the opponent's reply finish the game?
        outcome = self._terminal_reward()
        if outcome is not None:
            self.done = True
            return {"observation": self._obs()}, outcome, self.done, info

        return {"observation": self._obs()}, shaped_reward, self.done, info

    @classmethod
    async def run_rollouts(
        cls: Type[TicTacToe],
        env_args: Dict[str, Any],
        batch_size: int,
        group_size: int,
        max_steps: int,
        client: OpenAI,
        model_name: str,
        sampling_args: Dict[str, Any],
    ) -> Tuple[List[Rollout], Dict[str, float]]:
        """
        Play `batch_size * group_size` games, querying the model once per game per step.

        `batch_size` environments are created and reset; each one is then
        deep-copied `group_size` times so every group shares a starting position.

        Args:
            env_args: Keyword arguments for the environment constructor. Two
                extra keys are read here: 'strategy' (default "full_history";
                any other value sends only the system prompt and the latest
                observation) and 'prune_on_illegal_move' (bool, default False;
                when True a rollout that ended on an illegal move is reduced
                to that single step after statistics are tallied).
            batch_size: Number of distinct starting positions.
            group_size: Number of rollouts per starting position.
            max_steps: Maximum number of agent turns per game.
            client: OpenAI-compatible client used for chat completions.
            model_name: Model identifier passed to the client.
            sampling_args: Extra keyword arguments for the completion call.

        Returns:
            `(rollouts, stats)`: the validated rollouts and a dict with win,
            loss, draw and invalid-move rates plus the average game length.
        """
        strategy = env_args.get("strategy", "full_history")
        prune_on_illegal = env_args.get("prune_on_illegal_move", False)
        keep_history = strategy == "full_history"

        import copy

        # Build and reset the master environments, then replicate each one.
        master_envs = [cls(**env_args) for _ in range(batch_size)]
        master_states = [master.reset() for master in master_envs]

        envs: List[TicTacToe] = [
            copy.deepcopy(master) for master in master_envs for _ in range(group_size)
        ]
        initial_states: List[Dict[str, Any]] = [
            start for start in master_states for _ in range(group_size)
        ]

        total_rollouts = len(envs)
        wins = losses = draws = invalid_moves = 0
        total_game_steps = 0

        trajectories: List[List[StepResult]] = [[] for _ in range(total_rollouts)]
        histories: List[List[Dict[str, str]]] = [[] for _ in range(total_rollouts)]
        observations = [start[0]['observation'] for start in initial_states]
        system_prompts = [start[0]['system_prompt'] for start in initial_states]
        active_indices = list(range(total_rollouts))

        if keep_history:
            for idx in active_indices:
                histories[idx].append({"role": "system", "content": system_prompts[idx]})
                histories[idx].append({"role": "user", "content": observations[idx]})

        for _ in range(max_steps):
            if not active_indices:
                break

            # One prompt per game that is still running.
            if keep_history:
                prompts_to_send = [histories[idx] for idx in active_indices]
            else:
                prompts_to_send = []
                for idx in active_indices:
                    prompts_to_send.append([
                        {"role": "system", "content": system_prompts[idx]},
                        {"role": "user", "content": observations[idx]},
                    ])

            # The client is synchronous, so each request runs in a worker thread.
            pending = [
                asyncio.to_thread(
                    client.chat.completions.create,
                    model=model_name, messages=prompt, **sampling_args
                )
                for prompt in prompts_to_send
            ]
            responses = await asyncio.gather(*pending)

            still_running = []
            for env_idx, prompt_used, response in zip(active_indices, prompts_to_send, responses):
                env = envs[env_idx]
                raw_response = response.choices[0].message.content or ""
                parsed_action = extract_tag_value(raw_response, "move") or ""
                obs_dict, reward, done, info = env.step(parsed_action)
                next_obs = obs_dict['observation']

                is_invalid = info.get("invalid_move", False)

                step_data = StepResult(
                    prompt=prompt_used,
                    observation=observations[env_idx],
                    response=raw_response,
                    parsed_action=parsed_action,
                    reward=reward,
                    reward_dict={"game_outcome": reward},
                    done=done,
                    info=info
                )

                # Record every step so the game length below is the full length.
                trajectories[env_idx].append(step_data)

                if keep_history:
                    histories[env_idx].append({"role": "assistant", "content": raw_response})
                    if not done:
                        histories[env_idx].append({"role": "user", "content": next_obs})

                if not done:
                    still_running.append(env_idx)
                    observations[env_idx] = next_obs
                    continue

                # Game over: tally using the complete trajectory.
                total_game_steps += len(trajectories[env_idx])
                if is_invalid:
                    invalid_moves += 1
                elif reward == 1.0:
                    wins += 1
                elif reward == 0.0:
                    losses += 1
                elif reward == 0.5:
                    draws += 1

                # Only after tallying, optionally keep just the illegal step.
                if is_invalid and prune_on_illegal:
                    trajectories[env_idx] = [step_data]

            active_indices = still_running

        finished_games = wins + losses + draws + invalid_moves

        def _rate(count):
            # Fraction of all rollouts; 0 when there are no rollouts at all.
            return count / total_rollouts if total_rollouts > 0 else 0

        env_stats = {
            "win_rate": _rate(wins),
            "loss_rate": _rate(losses),
            "draw_rate": _rate(draws),
            "invalid_move_rate": _rate(invalid_moves),
            "avg_game_length": total_game_steps / finished_games if finished_games > 0 else 0,
        }

        validated_rollouts = [Rollout.model_validate(traj) for traj in trajectories]
        return validated_rollouts, env_stats

    # --- Helper methods ---
    def _terminal_reward(self) -> Optional[float]:
        """Return 1.0 for an agent win, 0.0 for a loss, 0.5 for a draw, None if the game continues."""
        for line in self.WIN_LINES:
            first, second, third = (self.board[i] for i in line)
            if first and first == second == third:
                return 1.0 if first == self.agent_mark else 0.0
        # No completed line: a full board is a draw, otherwise play goes on.
        if None in self.board:
            return None
        return 0.5

    def _obs(self) -> str:
        """Render the board as text, optionally preceded by the opponent's last move."""
        cells = [mark or " " for mark in self.board]
        rows = [
            f"{label} " + "|".join(cells[start:start + 3])
            for label, start in zip("ABC", (0, 3, 6))
        ]
        board_str = "  1 2 3\n" + "\n  -+-+-\n".join(rows)

        if not (self.show_opponent_move and self.last_opponent_move):
            return board_str
        return f"Opponent ({self.opponent_mark}) played at {self.last_opponent_move}.\n\n{board_str}"

    def _place(self, pos: int, mark: str) -> None:
        """Write `mark` into board cell `pos`."""
        self.board[pos] = mark

    def _empty_cells(self) -> List[int]:
        """Return the indices of all unoccupied board cells, in ascending order."""
        return [index for index, mark in enumerate(self.board) if mark is None]
===== ludic_envs/trainer.py =====
# adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py

import logging
from collections import defaultdict, deque
from contextlib import nullcontext
from typing import Optional, Union, Any, List, Dict, Tuple, Sized

import asyncio
import datasets
import openai
import torch
import deepspeed
from torch.utils.data import DataLoader
from accelerate.utils import broadcast_object_list, gather_object, is_peft_model
from peft import PeftConfig, get_peft_model
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, TrainerState, TrainingArguments, TrainerControl
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import seed_worker
from trl.models import create_reference_model, prepare_deepspeed
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.utils import (
    disable_dropout_in_model,
    pad,
    selective_log_softmax
)
import wandb
import numpy as np

from datasets import Dataset

from ludic_envs.ludic_env import LudicEnv
from ludic_envs.ludic_batch_generator import LudicAsyncBatchGenerator, LudicBatchRequest
from ludic_envs.ludic_train_config import LudicConfig

# torch.nanstd doesn't exist, so we define it here
def nanstd(tensor: torch.Tensor) -> torch.Tensor:
    """
    Sample standard deviation of a 1D tensor with NaN entries left out.

    Args:
        tensor (`torch.Tensor`):
            Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`:
            Standard deviation over the non-NaN entries, using Bessel's
            correction (division by `count - 1`).
    """
    valid_count = torch.sum(~torch.isnan(tensor))  # how many entries are not NaN
    centre = torch.nanmean(tensor, keepdim=True)
    squared_dev = (tensor - centre) ** 2
    variance = torch.nanmean(squared_dev)  # biased variance, NaNs ignored
    variance.mul_(valid_count / (valid_count - 1))  # Bessel's correction, in place
    return torch.sqrt(variance)

def split_tensor_dict(
    tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int
) -> list[dict[str, Optional[torch.Tensor]]]:
    """
    Splits a dictionary of tensors along the first dimension into `num_chunks` parts.

    This function uses `torch.tensor_split` to handle cases where the number of samples
    is not evenly divisible by `num_chunks`, ensuring no data is lost. The samples
    are distributed as evenly as possible across the chunks. Entries whose value is
    `None` stay `None` in every chunk.

    If there are fewer samples than chunks, the first returned element is the input
    dictionary itself and the remaining `num_chunks - 1` elements are placeholder
    dictionaries with every key mapped to `None`. Each placeholder is a separate
    dictionary object, so modifying one does not affect the others.

    Args:
        tensor_dict: Mapping from name to tensor (or `None`). The sample count is
            taken from the first non-`None` tensor.
        num_chunks: Number of chunks to return.

    Returns:
        A list of `num_chunks` dictionaries with the same keys as `tensor_dict`.

    Example:
        >>> x = torch.arange(14).reshape(7, 2)
        >>> y = torch.arange(7).reshape(7, 1)
        >>> tensor_dict = {"x": x, "y": y}
        >>> split_tensor_dict(tensor_dict, 3)
        [
            {'x': tensor([[0, 1], [2, 3], [4, 5]]), 'y': tensor([[0], [1], [2]])},
            {'x': tensor([[6, 7], [8, 9]]), 'y': tensor([[3], [4]])},
            {'x': tensor([[10, 11], [12, 13]]), 'y': tensor([[5], [6]])}
        ]
    """
    # Find the total number of samples from the first available tensor.
    first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
    total_samples = first_tensor.shape[0]

    # Too few samples to give every chunk at least one: keep all the data in the
    # first chunk and pad with empty chunks so callers still get `num_chunks` items.
    if total_samples < num_chunks:
        # Build a fresh placeholder per slot; `[d] * k` would alias a single dict.
        placeholders = [dict.fromkeys(tensor_dict) for _ in range(num_chunks - 1)]
        return [tensor_dict, *placeholders]

    # Split each tensor in the dictionary into `num_chunks`.
    # `torch.tensor_split` will handle the uneven distribution automatically.
    split_tensors = {
        key: torch.tensor_split(tensor, num_chunks) if tensor is not None else [None] * num_chunks
        for key, tensor in tensor_dict.items()
    }

    # Reorganize the per-key chunk lists into one dictionary per chunk.
    return [
        {key: chunks[i] for key, chunks in split_tensors.items()}
        for i in range(num_chunks)
    ]

def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[str, Optional[torch.Tensor]]:
    """
    Apply one random row permutation to every tensor in the dictionary.

    The permutation length is the first dimension of the first non-`None`
    tensor. `None` entries are passed through unchanged.

    Example (the actual order is random):
        >>> x = torch.arange(6).reshape(3, 2)
        >>> y = torch.arange(3).reshape(3, 1)
        >>> shuffle_tensor_dict({"x": x, "y": y})
        {'x': tensor([[2, 3],
                      [0, 1],
                      [4, 5]]),
         'y': tensor([[1],
                      [0],
                      [2]])}
    """
    reference = next(value for value in tensor_dict.values() if value is not None)
    order = torch.randperm(reference.shape[0])

    shuffled = {}
    for name, value in tensor_dict.items():
        # The same index tensor is used for every entry, keeping rows aligned.
        shuffled[name] = None if value is None else value[order]
    return shuffled

def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """
    Smallest non-NaN value of a 1D tensor.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: The minimum over the non-NaN entries. When there are no
        such entries, a NaN scalar with the input's dtype and device.
    """
    keep = ~torch.isnan(tensor)
    if not keep.any():
        # Nothing left to reduce over.
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    return tensor[keep].min()

def nanmax(tensor: torch.Tensor) -> torch.Tensor:
    """
    Largest non-NaN value of a 1D tensor.

    Args:
        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.

    Returns:
        `torch.Tensor`: The maximum over the non-NaN entries. When there are no
        such entries, a NaN scalar with the input's dtype and device.
    """
    keep = ~torch.isnan(tensor)
    if not keep.any():
        # Nothing left to reduce over.
        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
    return tensor[keep].max()

class SyncVLLMCallback(TrainerCallback):
    """
    Trainer callback that pushes the model weights to vLLM after each step.

    The work is delegated to the owning trainer's `_move_model_to_vllm`.
    """

    __slots__ = ("trainer",)

    def __init__(self, trainer):
        # Keep a handle on the trainer; it provides the logger and the sync routine.
        self.trainer = trainer

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Log the upcoming sync, then ask the trainer to transfer its weights."""
        owner = self.trainer
        message = f"Callback: Syncing weights to vLLM before global_step {state.global_step}"
        owner.logger.info(message)
        owner._move_model_to_vllm()

class LudicTrainer(Trainer):
    def __init__(
            self,
            model: PreTrainedModel,
            env: LudicEnv,
            args: LudicConfig,
            processing_class: PreTrainedTokenizerBase,
            callbacks: Optional[list[TrainerCallback]] = None,
            optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
            peft_config: Optional[PeftConfig] = None,
            **kwargs,
    ): 
        """
        Set up the trainer: model wrapping, hyperparameters, reference model,
        vLLM clients and the asynchronous rollout generator.

        Args:
            model: Policy model to train. Wrapped with `get_peft_model` when
                `peft_config` is given.
            env: Environment, stored on `self.env`.
            args: Training configuration; most fields are copied onto `self`.
            processing_class: Tokenizer. Its pad token is set to the EOS token
                when no pad token is defined.
            callbacks: Extra trainer callbacks forwarded to the parent class.
            optimizers: `(optimizer, scheduler)` pair forwarded to the parent class.
            peft_config: Optional PEFT configuration applied to `model`.
            **kwargs: Accepted but not used in this constructor.

        Raises:
            ValueError: if `args.max_steps` is not greater than 0.
        """
        self.logger = logging.getLogger(__name__)
        
        # Models
        if peft_config is not None:
            model = get_peft_model(model, peft_config) # type: ignore

        # Enable gradient checkpointing if requested
        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args) # type: ignore
        
        # Suppress irrelevant warning
        model.warnings_issued["estimate_tokens"] = True 

        # Tokenizer pad token: fall back to the EOS token when none is defined
        if processing_class.pad_token is None: # type: ignore
            processing_class.pad_token = processing_class.eos_token # type: ignore

        # Training arguments (copied from the config for direct access)
        self.per_device_train_batch_size = args.per_device_train_batch_size
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.max_concurrent = args.max_concurrent
        self.temperature = args.temperature
        self.top_p = args.top_p
        self.min_p = args.min_p
        self.repetition_penalty = args.repetition_penalty
        self.presence_penalty = args.presence_penalty
        self.frequency_penalty = args.frequency_penalty
        self.top_k = args.top_k
        self.loss_type = args.loss_type
        self.use_reward_as_advantage = args.use_reward_as_advantage
        self.scale_rewards = args.scale_rewards
        self.mask_truncated_completions = args.mask_truncated_completions
        self.delta = args.delta

        # Reference model parameters
        self.beta = args.beta
        self.sync_ref_model = args.sync_ref_model
        self.generation_batch_size: int = args.generation_batch_size # type: ignore

        # Multi-step
        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
        self.epsilon_low = args.epsilon
        # The upper clipping bound defaults to the lower one when not configured
        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
        self.gradient_accumulation_steps = args.gradient_accumulation_steps
        self._step = 0
        self._buffered_inputs: Optional[List[Dict[str, Optional[torch.Tensor]]]] = None
        

        # dummy dataset. The parent Trainer needs this to know how many training steps to run
        if args.max_steps <= 0:
            raise ValueError("LudicTrainer requires `max_steps > 0` to be set in the training arguments.")
        
        # One dummy row per (step, accumulation micro-step, process)
        dummy_len = args.max_steps * args.gradient_accumulation_steps * args.world_size
        dummy_dataset = Dataset.from_dict({"dummy_column": range(dummy_len)})

        # dummy data collator: returns the features unchanged
        def data_collator(features):
            return features
        
        # The same dummy dataset is passed for both training and evaluation
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            processing_class=processing_class,
            train_dataset=dummy_dataset,
            eval_dataset=dummy_dataset,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        # Reference model
        if self.beta == 0.0:
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
            # Under ZeRO-3 the reference model is loaded again from the model's name or path
            model_id = model.config._name_or_path
            model_init_kwargs = {"torch_dtype": "auto"}
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        elif is_peft_model(model):
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None
        else:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model) # type: ignore

        # Disable dropout in the models
        if args.disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        # Initialize the metrics
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._total_train_tokens = 0
        self.log_completions = args.log_completions
        self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
        self.num_completions_to_print = args.num_completions_to_print

        # Environment integration parameters
        self.mask_env_responses = args.mask_env_responses
        # NOTE: max_concurrent was already assigned above; this repeats the same value
        self.max_concurrent = args.max_concurrent

        # This will accumulate all textual data between logging steps.
        # Using simple lists is robust to variable-sized batches from rollouts.
        self._textual_logs = {
            "prompt": [],
            "completion": [],
            "rewards": defaultdict(list),
        }

        # OpenAI client for Environment generation (using vLLM server)
        host = args.vllm_server_host
        port = args.vllm_server_port
        vllm_base_url = f"http://{host}:{port}/v1"
        import httpx
        # The HTTP connection pool is capped at max_concurrent connections
        self.oai_client = openai.OpenAI(
            base_url=vllm_base_url,
            api_key="EMPTY",
            http_client=httpx.Client(
                limits=httpx.Limits(max_connections=args.max_concurrent),
                timeout=args.async_generation_timeout)
        )

        # Registers the callback that calls `_move_model_to_vllm` at the end of each step
        self.add_callback(SyncVLLMCallback(self))

        # vLLM client for weight syncing only; imported here rather than at module level
        from ludic_envs.inference.vllm_client import VLLMClient
        self.vllm_client = VLLMClient(
            host=host,
            port=port,
            connection_timeout=args.vllm_server_timeout
        )
        # Only initialize communicator on the main process
        # Other processes will only use the client for non-NCCL operations
        if self.accelerator.is_main_process:
            self.vllm_client.init_communicator()
        
        self._last_loaded_step = 0  # Initialize to 0 since vLLM already has initial weights
        self.model_accepts_loss_kwargs = False 
        # Weight updates to vLLM happen only when generating new completions
        # Frequency: every (gradient_accumulation_steps * num_iterations) training steps
        # When using vLLM, the main process is responsible for loading the model weights. This can cause process
        # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
        # synchronize all processes after vLLM has been fully initialized.
        self.accelerator.wait_for_everyone()

        # Reference model: prepare it for the active backend (DeepSpeed or plain accelerate)
        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
        # NOTE(review): this callback is added even when self.ref_model is None
        # (beta == 0.0 or a PEFT model) — confirm SyncRefModelCallback accepts that.
        if self.sync_ref_model:
            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) # type: ignore

        # Environment
        self.env = env
        self.env_args = args.env_args
        self.max_env_steps = args.max_env_steps

        # Async generation setup 
        self._next_batch_id: int = 0
        self._async_started = False
        self.num_batches_ahead = args.num_batches_ahead
        
        # num_batches_ahead=0 will behave synchronously (submit and wait immediately)
        self.async_generator = LudicAsyncBatchGenerator(
            client=self.oai_client,
            model_name=self._get_model_name(),
            sampling_args=self._get_sampling_args(),
            num_batches_ahead=self.num_batches_ahead, 
            max_queue_size=args.async_max_queue_size,
            generation_timeout=args.async_generation_timeout,
        )


    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: LudicConfig) -> PreTrainedModel:
        """
        Turn on gradient checkpointing for `model` and return it.

        `use_cache` is switched off on the model config first. For a PEFT model
        checkpointing is enabled on its `base_model`, otherwise on the model
        itself. Input gradients are enabled unless
        `args.gradient_checkpointing_kwargs` sets `use_reentrant` to a falsy value.
        """
        model.config.use_cache = False

        # PEFT wrappers expose the underlying network as `base_model`.
        checkpoint_target = model.base_model if is_peft_model(model) else model  # type: ignore
        checkpoint_target.gradient_checkpointing_enable()

        ckpt_kwargs = args.gradient_checkpointing_kwargs or {}
        assert isinstance(ckpt_kwargs, dict)

        # Re-entrant checkpointing is assumed when the option is not given.
        if ckpt_kwargs.get("use_reentrant", True):
            model.enable_input_require_grads()

        return model
    
    def _inner_training_loop(self, *args, **kwargs):
        """Run the parent training loop and stop the async generator once it exits, even on error."""
        try:
            outcome = super()._inner_training_loop(*args, **kwargs)
        finally:
            # Only the main process stops the generator; every process resets the started flag.
            generator_active = self.async_generator and self._async_started
            if generator_active and self.accelerator.is_main_process:
                self.async_generator.stop()
            self._async_started = False
        return outcome

    def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, logits_to_keep=None):
        """
        Return the last hidden states of the model's inner ``.model`` with the final position dropped.

        For a PEFT model the inner module is reached through ``base_model.model`` first.
        When ``logits_to_keep`` is given, only that many trailing positions are returned.
        """
        owner = unwrapped_model.base_model.model if is_peft_model(unwrapped_model) else unwrapped_model
        outputs = owner.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state[:, :-1, :]  # (B, L-1, H)
        if logits_to_keep is None:
            return hidden
        return hidden[:, -logits_to_keep:, :]  # (B, logits_to_keep, H)

    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor:
        """
        Compute log-probabilities of the last ``logits_to_keep`` tokens of every sequence.

        The inputs are processed in row chunks of ``batch_size`` (all rows at once when it is
        None or 0) to lower peak memory, and the per-chunk results are concatenated on dim 0.
        Logits are divided by ``self.temperature`` before the log-softmax.
        """
        total_rows = input_ids.size(0)
        chunk_size = batch_size or total_rows
        chunk_logps = []
        for start in range(0, total_rows, chunk_size):
            stop = start + chunk_size
            ids_chunk = input_ids[start:stop]
            mask_chunk = attention_mask[start:stop]

            # Ask for one extra position: the final logit predicts a token past the sequence
            # and is discarded right below.
            outputs = model(
                input_ids=ids_chunk, attention_mask=mask_chunk, logits_to_keep=logits_to_keep + 1
            )
            chunk_logits = outputs.logits[:, :-1, :]  # (B, L-1, V)
            # For transformers<=4.48 the `logits_to_keep` argument is not supported, so the
            # logits are trimmed here as well. See https://github.com/huggingface/trl/issues/2770
            chunk_logits = chunk_logits[:, -logits_to_keep:]
            # Scale by the sampling temperature.
            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
            chunk_logits = chunk_logits / self.temperature

            target_ids = ids_chunk[:, -logits_to_keep:]
            chunk_logps.append(selective_log_softmax(chunk_logits, target_ids))
        return torch.cat(chunk_logps, dim=0)

    def _move_model_to_vllm(self):
        """
        Send the current model weights to the vLLM server.

        All processes enter a barrier, then a DeepSpeed ``GatheredParameters`` context
        (enabled only under ZeRO-3). Inside it a PEFT adapter is merged, the main process
        pushes every named parameter through ``self.vllm_client`` and resets the prefix
        cache, and the adapter is unmerged again. A second barrier closes the sync.
        """
        zero3_active = is_deepspeed_zero3_enabled()

        self.logger.info(f"Process {self.accelerator.process_index}: Entering model weight sync.")
        self.accelerator.wait_for_everyone()

        # With `enabled=False` the context is a no-op, so non-ZeRO-3 runs share this path.
        with deepspeed.zero.GatheredParameters(self.model.parameters(), enabled=zero3_active):
            # Merge first so the parameters sent below already contain the adapter.
            if is_peft_model(self.model):
                self.model.merge_adapter()

            # Only the main process talks to vLLM.
            if self.accelerator.is_main_process:
                self.logger.info("Rank 0: Syncing gathered parameters to vLLM...")

                for raw_name, weight in self.model.named_parameters():
                    tensor = weight if weight.is_contiguous() else weight.contiguous()

                    # Strip PEFT wrapping from the name and skip adapter-only entries.
                    vllm_name = raw_name.removeprefix("base_model.model.").replace(".base_layer", "")
                    if getattr(self.model, "prefix", "___nonexistent___") in vllm_name:
                        continue
                    if "original_module" in vllm_name:
                        continue
                    vllm_name = vllm_name.replace("modules_to_save.default.", "")

                    self.vllm_client.update_named_param(vllm_name, tensor.data)

                self.logger.info("Rank 0: Parameter sync complete. Resetting vLLM cache.")
                self.vllm_client.reset_prefix_cache()

            # Undo the merge while still inside the gather context.
            if is_peft_model(self.model):
                self.model.unmerge_adapter()

        # Hold every process here until the main process has finished its vLLM calls.
        self.logger.info(f"Process {self.accelerator.process_index}: Exiting model weight sync.")
        self.accelerator.wait_for_everyone()
    
    def _get_sampling_args(self) -> Dict[str, Any]:
        """Get sampling arguments for Environment generation."""
        args = {
            'temperature': self.temperature,
            'top_p': self.top_p,
            'max_tokens': self.max_completion_length,
            'n': 1,
            'presence_penalty': self.presence_penalty,
            'frequency_penalty': self.frequency_penalty,
            'extra_body': {
                'top_k': self.top_k,
                'min_p': self.min_p,
                'repetition_penalty': self.repetition_penalty,
            }
        }
        return args

    def _get_model_name(self) -> str:
        """Get model name for Environment generation."""
        return self.model.config._name_or_path # type: ignore

    def _ids_to_tensors(self,
                        prompt_ids: List[List[int]],
                        prompt_mask: List[List[int]],
                        completion_ids: List[List[int]],
                        completion_mask: List[List[int]],
                        device: torch.device) -> Dict[str, torch.Tensor]:
    
        ids = [prompt_ids[i] + completion_ids[i] for i in range(len(prompt_ids))] 
        mask = [prompt_mask[i] + completion_mask[i] for i in range(len(prompt_mask))]
        max_len = max(len(ids[i]) for i in range(len(ids)))
        ids = [torch.cat([
            torch.tensor(ids[i], dtype=torch.long, device=device),
            torch.zeros(max_len - len(ids[i]), dtype=torch.long, device=device)
        ]) for i in range(len(ids))]
        mask = [torch.cat([
            torch.tensor(mask[i], dtype=torch.long, device=device),
            torch.zeros(max_len - len(mask[i]), dtype=torch.long, device=device)
        ]) for i in range(len(mask))]
        ids = torch.stack(ids, dim=0)   
        mask = torch.stack(mask, dim=0)
        return {
            'ids': ids,
            'mask': mask
        }

    def _prepare_inputs(self, inputs: list[dict[str, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        """
        Generates and prepares a batch of data for training using the LudicAsyncBatchGenerator.

        New rollouts are fetched only when `self._step` is a multiple of
        `gradient_accumulation_steps * num_iterations` (or nothing is buffered yet). In that case:
        1. The main process submits requests so the generator stays `num_batches_ahead`.
        2. The main process retrieves the completed batch and broadcasts it to all processes.
        3. Every process computes advantages, takes its own slice of the steps, pads it into
           tensors, shuffles it and splits it into `gradient_accumulation_steps` buffered parts.

        Every call then returns the buffered part selected by `self._step` and increments
        `self._step`.

        Args:
            inputs: Not read by this method; the training data comes from the generator.

        Returns:
            One buffered part with the keys `prompt_ids`, `prompt_mask`, `completion_ids`,
            `completion_mask`, `old_per_token_logps` (always None here) and `advantages`.
        """
        self.accelerator.wait_for_everyone()

        # Number of calls served from one retrieved batch before a new one is fetched.
        generate_every = self.gradient_accumulation_steps * self.num_iterations

        if self._step % generate_every == 0 or self._buffered_inputs is None:
            # ===================================================================
            # --- PART 1: BATCH SUBMISSION ---
            # ===================================================================
            # The generator is started on the main process only; the flag is set on every process.
            if not self._async_started and self.accelerator.is_main_process:
                self.async_generator.start()
            self._async_started = True
            self.accelerator.wait_for_everyone()

            batch_id_to_retrieve = self._step // generate_every
            target_batch_id = batch_id_to_retrieve + self.num_batches_ahead

            # Submit every batch id that has not been requested yet, up to the look-ahead target.
            batches_submitted = 0
            for batch_id in range(self._next_batch_id, target_batch_id + 1):
                if self.accelerator.is_main_process:
                    self.logger.info(f"Submitting generation request for batch_id: {batch_id}")
                    request = LudicBatchRequest(
                        batch_id=batch_id,
                        env_class=type(self.env),
                        env_args=self.env_args,
                        batch_size=self.per_device_train_batch_size * self.gradient_accumulation_steps,
                        group_size=self.num_generations,
                        max_steps=self.max_env_steps,
                        processing_class=self.processing_class,
                        max_completion_length=self.max_completion_length,
                        mask_truncated_completions=self.mask_truncated_completions,
                        mask_env_responses=False
                    )
                    self.async_generator.submit_batch(request)
                    batches_submitted += 1

            if self.accelerator.is_main_process and batches_submitted > 0:
                self._next_batch_id += batches_submitted

            # Only the main process advanced the counter; share its value with the other processes.
            next_batch_id_list = [self._next_batch_id if self.accelerator.is_main_process else 0]
            broadcast_object_list(next_batch_id_list, from_process=0)
            self._next_batch_id = next_batch_id_list[0]
            self.accelerator.wait_for_everyone()

            # --- PART 2: BATCH RETRIEVAL ---
            if self.accelerator.is_main_process:
                self.logger.info(f"Retrieving batch {batch_id_to_retrieve} for processing...")
                batch_result = self.async_generator.get_batch(batch_id_to_retrieve)

                # Collect everything the other processes need into one object for broadcasting.
                processed = batch_result.processed_results
                broadcast_data = {
                    'prompt_ids': processed['prompt_ids'],
                    'prompt_mask': processed['prompt_mask'],
                    'completion_ids': processed['completion_ids'],
                    'completion_mask': processed['completion_mask'],
                    'rewards': processed['rewards'],
                    'rollout_ids': processed['rollout_ids'],
                    'group_ids': processed['group_ids'],
                    'all_reward_dict': batch_result.all_reward_dict,
                    'completions': batch_result.completions,
                    'prompts': batch_result.prompts,
                    'env_stats': batch_result.env_stats
                }

                self._log_generation_to_file(batch_id_to_retrieve, broadcast_data)

            else:
                # Placeholder that the broadcast below overwrites.
                broadcast_data = None

            # ===================================================================
            # --- PART 3: DATA PROCESSING & TENSORIZATION ---
            # ===================================================================
            broadcast_list = [broadcast_data]
            broadcast_object_list(broadcast_list, from_process=0)
            broadcast_data = broadcast_list[0]
            assert broadcast_data is not None
            self.accelerator.wait_for_everyone()

            # --- Advantage Calculation ---
            # One reward per step; `rollout_ids` gives the rollout each step belongs to.
            all_step_rewards = torch.tensor(broadcast_data['rewards'], device=self.accelerator.device)

            # 1. ALWAYS compute trajectory-level rewards and advantages
            all_rollout_ids = torch.tensor(broadcast_data['rollout_ids'], device=self.accelerator.device)
            num_rollouts = self.args.per_device_train_batch_size * self.args.num_generations * self.gradient_accumulation_steps
            trajectory_rewards = torch.zeros(num_rollouts, device=self.accelerator.device)
            # With accumulate=True, the rewards of all steps sharing a rollout id are summed.
            trajectory_rewards.index_put_((all_rollout_ids,), all_step_rewards, accumulate=True)
            trajectory_advantages = self._compute_advantages(trajectory_rewards)

            # 2. THEN, branch to select the final advantages for the loss function
            if self.use_reward_as_advantage:
                all_step_advantages = all_step_rewards
            else:
                # Each step takes the advantage of the rollout it belongs to.
                all_step_advantages = trajectory_advantages[all_rollout_ids]

            # --- Slicing and Padding for Local Process ---
            num_total_steps = len(broadcast_data['prompt_ids'])
            num_processes = self.accelerator.num_processes
            process_index = self.accelerator.process_index

            # Contiguous split of the steps: the first `remainder` processes take one extra step.
            steps_per_process = num_total_steps // num_processes
            remainder = num_total_steps % num_processes

            if process_index < remainder:
                start_index = process_index * (steps_per_process + 1)
                end_index = start_index + steps_per_process + 1
            else:
                start_index = process_index * steps_per_process + remainder
                end_index = start_index + steps_per_process

            process_slice = slice(start_index, end_index)

            # Slice all data for the local process
            local_advantages = all_step_advantages[process_slice]
            local_prompt_ids = [torch.tensor(x) for x in broadcast_data['prompt_ids'][process_slice]]
            local_prompt_mask = [torch.tensor(x) for x in broadcast_data['prompt_mask'][process_slice]]
            local_completion_ids = [torch.tensor(x) for x in broadcast_data['completion_ids'][process_slice]]
            local_completion_mask = [torch.tensor(x) for x in broadcast_data['completion_mask'][process_slice]]

            # Pad sequences to create a uniform batch for the model
            prompt_ids = pad(local_prompt_ids, padding_value=self.processing_class.pad_token_id)
            prompt_mask = pad(local_prompt_mask)
            completion_ids = pad(local_completion_ids, padding_value=self.processing_class.pad_token_id)
            completion_mask = pad(local_completion_mask)

            # --- Final Preparation (Truncating, Logging, Buffering) ---
            # Over-long prompts keep their last `max_prompt_length` columns.
            if self.max_prompt_length is not None and prompt_ids.size(1) > self.max_prompt_length:
                self.logger.warning("MAX PROMPT LENGTH REACHED. TRUNCATING...")
                prompt_ids = prompt_ids[:, -self.max_prompt_length:]
                prompt_mask = prompt_mask[:, -self.max_prompt_length:]

            # Over-long completions keep their first `max_completion_length` columns.
            if self.max_completion_length is not None and completion_ids.size(1) > self.max_completion_length:
                completion_ids = completion_ids[:, :self.max_completion_length]
                completion_mask = completion_mask[:, :self.max_completion_length]

            if self.accelerator.is_main_process:
                # Log quantitative metrics
                self._log_reward_metrics_primary(
                    mode="train",
                    all_reward_dict=broadcast_data['all_reward_dict'],
                    all_rewards=trajectory_rewards,
                    generation_batch_size=self.generation_batch_size
                )
                
                # Combine all per-step rewards into one dictionary for logging.
                text_log_rewards = broadcast_data['all_reward_dict'].copy()
                # Add the final, consolidated per-step reward to the dictionary.
                text_log_rewards['consolidated_reward'] = broadcast_data['rewards']

                # Log qualitative (textual) data using the consistent per-step data.
                self._log_textual_data_primary(
                    all_prompts=broadcast_data['prompts'],
                    all_completions=broadcast_data['completions'],
                    all_reward_dict=text_log_rewards
                )

                # Log completion and environment metrics
                self._log_completion_metrics_primary(
                    mode="train",
                    all_completion_mask=broadcast_data["completion_mask"],
                    all_completion_ids=broadcast_data["completion_ids"],
                    all_prompt_mask=broadcast_data["prompt_mask"],
                )
                # `batch_result` exists only on the main process, which is the branch we are in.
                if batch_result.env_stats:
                    self.logger.info(f"Logging environment stats: {batch_result.env_stats}")
                    for key, value in batch_result.env_stats.items():
                        self._metrics["train"][f"env/{key}"].append(value)

            full_batch = {
                "prompt_ids": prompt_ids.to(self.accelerator.device),
                "prompt_mask": prompt_mask.to(self.accelerator.device),
                "completion_ids": completion_ids.to(self.accelerator.device),
                "completion_mask": completion_mask.to(self.accelerator.device),
                # Left as None; compute_loss then uses the detached current log-probs instead.
                "old_per_token_logps": None,
                "advantages": local_advantages,
            }

            full_batch = shuffle_tensor_dict(full_batch)
            self._buffered_inputs = split_tensor_dict(full_batch, self.gradient_accumulation_steps)
            self.accelerator.wait_for_everyone()

        # Serve the buffered part that corresponds to the current accumulation step.
        result = self._buffered_inputs[self._step % self.gradient_accumulation_steps]
        self._step += 1
        self.accelerator.wait_for_everyone()
        return result


    # TODO: Turn into strategy pattern
    def _compute_advantages(
        self,
        rewards: torch.Tensor,
    ) -> torch.Tensor:
        """Compute advantages from rewards with normalization using full batch statistics."""
        # Always use full batch statistics
        mean_grouped = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped = rewards.view(-1, self.num_generations).std(dim=1)
        
        # Normalize the rewards to compute advantages
        mean_grouped = mean_grouped.repeat_interleave(self.num_generations, dim=0)
        std_grouped = std_grouped.repeat_interleave(self.num_generations, dim=0)
        advantages = rewards - mean_grouped
        
        if self.scale_rewards:
            advantages = advantages / (std_grouped + 1e-4)
        
        return advantages


    def compute_loss(self,
                     model: PreTrainedModel,
                     inputs: Dict[str, torch.Tensor],
                     return_outputs: bool = False,
                     num_items_in_batch: int | None = None) -> torch.Tensor:
        """
        Compute the clipped policy-gradient loss for one prepared batch.

        The probability ratio between the current and the old per-token log-probs is clipped
        to ``[1 - epsilon_low, 1 + epsilon_high]`` (and additionally capped at ``delta`` on the
        unclipped side when ``delta`` is set). When ``beta`` is non-zero a per-token KL term
        against the reference model is added. The per-token loss is then reduced according to
        ``self.loss_type``. KL and clip-ratio metrics are recorded under the "train" mode.

        Args:
            model: Policy model to evaluate.
            inputs: Batch with ``prompt_ids``, ``prompt_mask``, ``completion_ids``,
                ``completion_mask``, ``advantages`` and ``old_per_token_logps`` (may be None).
            return_outputs: Not read by this method.
            num_items_in_batch: Not read by this method.

        Returns:
            The scalar loss tensor.

        Raises:
            ValueError: If ``self.loss_type`` is not one of the handled values.
        """
        mode = "train"

        prompt_ids = inputs["prompt_ids"]
        prompt_mask = inputs["prompt_mask"]
        completion_ids = inputs["completion_ids"]
        completion_mask = inputs["completion_mask"]

        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        # Only the completion positions contribute to the loss.
        logits_to_keep = completion_ids.size(1)
        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # One advantage per sequence, broadcast over its tokens.
        token_advantages = inputs["advantages"].unsqueeze(1)

        # Without stored old log-probs, the detached current ones stand in for them.
        old_per_token_logps = inputs["old_per_token_logps"]
        if old_per_token_logps is None:
            old_per_token_logps = per_token_logps.detach()

        ratio = torch.exp(per_token_logps - old_per_token_logps)
        clipped_ratio = torch.clamp(ratio, 1 - self.epsilon_low, 1 + self.epsilon_high)
        # Optional upper cap on the otherwise unclipped ratio.
        capped_ratio = ratio if self.delta is None else torch.clamp(ratio, max=self.delta)

        per_token_loss = -torch.min(capped_ratio * token_advantages, clipped_ratio * token_advantages)

        # KL penalty against the reference policy.
        if self.beta != 0.0:
            with torch.no_grad():
                if self.ref_model is None:
                    # No separate reference model: evaluate the policy with its adapter disabled.
                    with self.accelerator.unwrap_model(self.model).disable_adapter():  # type: ignore
                        ref_per_token_logps = self._get_per_token_logps(
                            self.model, input_ids, attention_mask, logits_to_keep
                        )
                else:
                    ref_per_token_logps = self._get_per_token_logps(
                        self.ref_model, input_ids, attention_mask, logits_to_keep
                    )
            log_ratio_to_ref = ref_per_token_logps - per_token_logps
            per_token_kl = torch.exp(log_ratio_to_ref) - log_ratio_to_ref - 1
            per_token_loss = per_token_loss + self.beta * per_token_kl
            mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
            self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())  # type: ignore

        # Reduce the per-token loss to a scalar.
        if self.loss_type == "grpo":
            per_sequence = (per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
            loss = per_sequence.mean()
        elif self.loss_type in ["bnpo", "direct"]:
            loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
        elif self.loss_type == "dr_grpo":
            normalizer = per_token_loss.size(0) * self.max_completion_length
            loss = (per_token_loss * completion_mask).sum() / normalizer  # type: ignore
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        # Fraction of completion tokens where the clipping was active.
        clipped_below = (ratio < 1 - self.epsilon_low) & (token_advantages < 0)
        clipped_above = (ratio > 1 + self.epsilon_high) & (token_advantages > 0)
        clipped_any = clipped_below | clipped_above

        mask_total = completion_mask.sum()
        low_clip = (clipped_below * completion_mask).sum() / mask_total
        high_clip = (clipped_above * completion_mask).sum() / mask_total
        clip_ratio = (clipped_any * completion_mask).sum() / mask_total

        train_metrics = self._metrics[mode]
        gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
        train_metrics["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())  # type: ignore
        train_metrics["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())  # type: ignore
        gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
        train_metrics["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())  # type: ignore
        train_metrics["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())  # type: ignore
        gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
        train_metrics["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())  # type: ignore

        self.logger.info(f"Computed loss at internal step {self._step}, global step {self.state.global_step}, rank {self.accelerator.process_index}")

        return loss
    
    def evaluate(self, ignore_keys=None, metric_key_prefix="eval", **kwargs):
        """
        Run evaluation rollouts on the main process only and log the environment statistics.

        Every other process blocks on a single barrier and returns an empty dict. The main
        process reaches that barrier only after its rollouts and logging are finished.

        Args:
            ignore_keys: Not read by this method.
            metric_key_prefix: Prefix joined to each metric name with an underscore.
            **kwargs: Not read by this method.

        Returns:
            The prefixed metrics dict on the main process, ``{}`` on every other process.
        """
        process_rank = self.accelerator.process_index

        # Non-main processes only wait for the main process to finish, then leave.
        if not self.accelerator.is_main_process:
            self.logger.info(f"Rank {process_rank}: Waiting at evaluate() entry barrier.")
            self.accelerator.wait_for_everyone()
            self.logger.info(f"Rank {process_rank}: Released from evaluate() entry barrier. Exiting.")
            return {}

        # --- Main process only from here on ---
        self.logger.info(f"Rank {process_rank}: Is main process, beginning evaluation.")

        # Drop textual logs left over from earlier steps so they do not mix with eval samples.
        if self.log_completions:
            self.logger.info(f"Rank {process_rank}: Clearing stale textual logs for evaluation.")
            for field in ("prompt", "completion", "rewards"):
                self._textual_logs[field].clear()

        self.logger.info(f"Rank {process_rank}: Starting evaluation rollouts at global_step={self.state.global_step}...")

        eval_batch_size = self.args.per_device_eval_batch_size
        # Evaluation samples with a fixed temperature instead of the training one.
        sampling_args = self._get_sampling_args()
        sampling_args["temperature"] = 0.6

        # Reuse the running event loop when there is one; otherwise create and install a new loop.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        rollout_job = self.env.run_rollouts(
            env_args=self.env_args,
            batch_size=eval_batch_size,
            group_size=1,
            max_steps=self.max_env_steps,
            client=self.oai_client,
            model_name=self._get_model_name(),
            sampling_args=sampling_args,
        )
        rollouts, env_stats = loop.run_until_complete(rollout_job)
        self.logger.info(f"Rank {process_rank}: Evaluation rollouts complete.")

        metrics = {f"env/{key}": value for key, value in env_stats.items()}

        if self.log_completions:
            self.logger.info(f"Rank {process_rank}: Logging {len(rollouts)} completion examples from evaluation.")
            for rollout in rollouts:
                if not rollout:
                    continue
                # One table row per game: first observation, all responses, summed reward.
                step_texts = (f"Step {i+1}:\n{step.response}" for i, step in enumerate(rollout))
                self._textual_logs["prompt"].append(str(rollout[0].observation))
                self._textual_logs["completion"].append("\n---\n".join(step_texts))
                self._textual_logs["rewards"]["eval_game_outcome"].append(
                    sum(step.reward for step in rollout)
                )

        prefixed_metrics = {f"{metric_key_prefix}_{k}": v for k, v in metrics.items()}

        self.logger.info(f"Rank {process_rank}: Evaluation metrics calculated. Calling self.log().")
        self.log(prefixed_metrics)
        self.logger.info(f"Rank {process_rank}: Returned from self.log().")

        # This barrier pairs with the one the non-main processes are waiting on.
        self.logger.info(f"Rank {process_rank}: Waiting at evaluate() exit barrier.")
        self.accelerator.wait_for_everyone()
        self.logger.info(f"Rank {process_rank}: Released from evaluate() exit barrier. Returning metrics.")

        return prefixed_metrics

    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """
        Merge the custom metric buffers into ``logs`` and hand them to the parent logger.

        The mode is "eval" when any key of ``logs`` starts with "eval", otherwise "train".
        In train mode the per-process metric buffers are gathered (as plain dicts) and
        averaged on the main process; in eval mode nothing is gathered. The main process
        also writes the wandb completions table and clears the textual logs. Every process
        clears its metric buffer for the mode, and train mode ends with a barrier.

        Args:
            logs: Metrics to log; updated in place with the averaged custom metrics.
            start_time: Forwarded to the parent ``log`` when it is not None.
        """
        is_eval = any(key.startswith("eval") for key in logs.keys())
        mode = "eval" if is_eval else "train"

        # === Step 1: gather the custom metrics (collective, train mode only) ===
        if mode == "train":
            # A plain dict is gathered rather than the defaultdict itself.
            gathered_metrics_list = gather_object([dict(self._metrics[mode])])
        else:
            gathered_metrics_list = []

        # === Step 2: the main process aggregates and performs the I/O ===
        if self.accelerator.is_main_process:
            final_custom_metrics = {}
            if mode == "train":
                # Anything gathered that is not a dict is ignored.
                rank_dicts = [d for d in gathered_metrics_list if isinstance(d, dict)]
                key_sets = [set(d.keys()) for d in rank_dicts]
                collective_keys = set.intersection(*key_sets) if key_sets else set()

                # Keys reported by every process: average over all values pooled together.
                for key in collective_keys:
                    pooled = [v for d in rank_dicts for v in d.get(key, [])]
                    if pooled:
                        final_custom_metrics[key] = (
                            torch.as_tensor(pooled, dtype=torch.float32).mean().item()
                        )

                # Remaining keys are taken from the first gathered dict only.
                if gathered_metrics_list and isinstance(gathered_metrics_list[0], dict):
                    for key, val_list in gathered_metrics_list[0].items():
                        if key in collective_keys or not val_list:
                            continue
                        final_custom_metrics[key] = (
                            torch.as_tensor(val_list, dtype=torch.float32).mean().item()
                        )

                logs.update(final_custom_metrics)

            # Pass `start_time` on only when the caller supplied one.
            parent_args = (logs,) if start_time is None else (logs, start_time)
            super().log(*parent_args)

            # --- Wandb table of prompts, completions and rewards ---
            wants_table = self.log_completions and "wandb" in self.args.report_to
            if wants_table and len(self._textual_logs["prompt"]) > 0:
                import pandas as pd
                import wandb
                row_count = len(self._textual_logs["prompt"])
                table = {
                    "step": [str(self.state.global_step)] * row_count,
                    "prompt": list(self._textual_logs["prompt"]),
                    "completion": list(self._textual_logs["completion"]),
                    **{k: list(v) for k, v in self._textual_logs["rewards"].items()},
                }
                try:
                    df = pd.DataFrame(table)
                    if self.wandb_log_unique_prompts:
                        df = df.drop_duplicates(subset=["prompt"])
                    wandb.log({"completions": wandb.Table(dataframe=df)})
                except ValueError as e:
                    self.logger.warning(f"Could not create wandb table at step {self.state.global_step}. Error: {e}")

            # The textual logs have been consumed; reset them.
            for field in ("prompt", "completion", "rewards"):
                self._textual_logs[field].clear()

        # === Step 3: cleanup and synchronization on every process ===
        self._metrics[mode].clear()

        # Eval mode skips the gather above, so the barrier is used in train mode only.
        if mode == "train":
            self.accelerator.wait_for_everyone()

    def _log_reward_metrics_primary(
        self,
        mode: str,
        all_reward_dict: Dict[str, Any],
        all_rewards: torch.Tensor,
        generation_batch_size: int
    ) -> None:
        """
        Log generation metrics (PRIMARY PROCESS ONLY).

        Appends the reward statistics of the full batch, plus the mean score of
        every individual reward function, to ``self._metrics[mode]``.

        Args:
            mode: Key into ``self._metrics`` (e.g. "train").
            all_reward_dict: Maps a reward name to its scores, given either as a
                list of numbers or as a tensor. The ``'reward'`` entry is
                skipped because the consolidated reward is logged from
                ``all_rewards`` instead.
            all_rewards: Flat tensor of consolidated rewards. It is viewed as
                rows of ``self.num_generations`` consecutive entries, so its
                length must be a multiple of ``self.num_generations``.
            generation_batch_size: Not used by this method; kept so the call
                signature stays unchanged.
        """
        metrics = self._metrics[mode]

        # Group once and reuse the view for both statistics.
        grouped = all_rewards.view(-1, self.num_generations)
        if not grouped.is_floating_point():
            # mean()/std() are only defined for floating-point tensors.
            grouped = grouped.float()
        metrics["reward"].append(grouped.mean(dim=1).mean().item())
        metrics["reward_std"].append(grouped.std(dim=1).mean().item())

        # Log individual reward function scores as metrics
        for reward_key, reward_values in all_reward_dict.items():
            if reward_key == 'reward':  # Skip the consolidated reward
                continue
            if isinstance(reward_values, list):
                # Force a float dtype: a list of ints/bools would otherwise
                # become an integer tensor, on which .mean() raises.
                reward_tensor = torch.tensor(
                    reward_values, dtype=torch.float32, device=all_rewards.device
                )
            else:
                reward_tensor = reward_values
                if isinstance(reward_tensor, torch.Tensor) and not reward_tensor.is_floating_point():
                    reward_tensor = reward_tensor.float()
            metrics[f"rewards/{reward_key}"].append(reward_tensor.mean().item())

    def _log_textual_data_primary(
        self,
        all_prompts: List[Union[str, List[Dict[str, Any]]]],
        all_completions: List[Union[str, List[Dict[str, Any]]]],
        all_reward_dict: Dict[str, Any]
    ) -> None:
        """
        Buffer textual data for later wandb logging (PRIMARY PROCESS ONLY).

        Extends ``self._textual_logs`` with the given prompts and completions,
        and extends ``self._textual_logs["rewards"][name]`` with the scores of
        every entry in ``all_reward_dict`` (individual reward functions and the
        consolidated reward alike). Tensor scores are converted to plain lists
        first; any other value is passed to ``extend`` unchanged.
        """
        buffers = self._textual_logs
        buffers["prompt"].extend(all_prompts)
        buffers["completion"].extend(all_completions)

        for name, scores in all_reward_dict.items():
            if isinstance(scores, torch.Tensor):
                scores = scores.tolist()
            self._textual_logs["rewards"][name].extend(scores)

    def _log_completion_metrics_primary(
        self,
        mode: str,
        all_completion_mask: List[List[int]],
        all_completion_ids: List[List[int]],
        all_prompt_mask: List[List[int]]
    ) -> None:
        """
        Log completion-related metrics (PRIMARY PROCESS ONLY).

        Appends completion length statistics for the full batch to
        ``self._metrics[mode]``. A completion counts as "terminated" when one
        of its unmasked tokens equals the EOS token id; the rest are counted
        as clipped.

        An empty batch logs zeros for every statistic instead of raising.

        Args:
            mode: Key into ``self._metrics``; the running token counter is only
                updated when this is "train".
            all_completion_mask: Per-completion masks; a completion's length is
                the sum of its mask.
            all_completion_ids: Per-completion token ids, aligned with
                ``all_completion_mask``.
            all_prompt_mask: Per-prompt masks, used only for the token count.
        """
        metrics = self._metrics[mode]

        # Log token count, only during training.
        # NOTE(review): this counts list lengths, so masked-out positions are
        # included — confirm the masks are unpadded if real tokens are intended.
        if mode == "train":
            total_tokens = sum(len(pm) + len(cm) for pm, cm in zip(all_prompt_mask, all_completion_mask))
            self.state.num_input_tokens_seen += total_tokens
            metrics["num_tokens"].append(self.state.num_input_tokens_seen)

        # Log completion lengths; fall back to [0] so an empty batch yields
        # zeros rather than a ZeroDivisionError / ValueError.
        completion_lengths = [sum(mask) for mask in all_completion_mask]
        length_stats = completion_lengths or [0]
        metrics["completions/mean_length"].append(float(sum(length_stats)) / len(length_stats))
        metrics["completions/min_length"].append(float(min(length_stats)))
        metrics["completions/max_length"].append(float(max(length_stats)))

        # Check for EOS tokens among the unmasked positions of each completion.
        eos_token_id = self.processing_class.eos_token_id  # type: ignore
        term_lengths = [
            length
            for comp_ids, comp_mask, length in zip(all_completion_ids, all_completion_mask, completion_lengths)
            if any(token == eos_token_id for token, mask in zip(comp_ids, comp_mask) if mask)
        ]

        clipped_completions_ratio = 1 - len(term_lengths) / len(completion_lengths) if completion_lengths else 0
        metrics["completions/clipped_ratio"].append(clipped_completions_ratio)

        # No terminated completion: report zeros.
        term_stats = term_lengths or [0]
        metrics["completions/mean_terminated_length"].append(float(sum(term_stats)) / len(term_stats))
        metrics["completions/min_terminated_length"].append(float(min(term_stats)))
        metrics["completions/max_terminated_length"].append(float(max(term_stats)))

    def _log_generation_to_file(self, batch_id: int, data_to_log: dict) -> None:
        """
        Append one JSON line describing a generation batch to a log file.

        The line is written to ``generation_log.jsonl`` inside
        ``self.args.output_dir`` (created if missing) and holds the batch id,
        ``self.state.global_step``, ``self._step`` and a JSON-friendly copy of
        ``data_to_log``: tensors, including tensors one level deep inside a
        dict value, are converted to lists; everything else is stored as-is.

        If ``output_dir`` is empty, a warning is logged and nothing is written.
        Any exception is logged and swallowed so that a logging failure never
        interrupts training.
        """
        import os
        import json

        def _as_plain(item):
            # Tensors are not JSON-serializable; other values pass through untouched.
            return item.tolist() if isinstance(item, torch.Tensor) else item

        try:
            # The training arguments decide where the log lives.
            target_dir = self.args.output_dir
            if not target_dir:
                self.logger.warning("`output_dir` not set, skipping file logging of generation data.")
                return

            os.makedirs(target_dir, exist_ok=True)
            jsonl_path = os.path.join(target_dir, "generation_log.jsonl")

            # Metadata first, so every line can be traced back to a step.
            record = {
                "batch_id": batch_id,
                "global_step": self.state.global_step,
                "internal_step": self._step,
                "data": {},
            }

            payload = record["data"]
            for field, content in data_to_log.items():
                if isinstance(content, dict):
                    # Nested dictionaries (like all_reward_dict) are sanitized one level deep.
                    payload[field] = {name: _as_plain(entry) for name, entry in content.items()}
                else:
                    payload[field] = _as_plain(content)

            # Append mode keeps earlier batches intact.
            with open(jsonl_path, "a") as sink:
                sink.write(json.dumps(record) + "\n")

        except Exception as e:
            # Prevent logging failures from crashing the training run
            self.logger.error(f"Failed to log generation data to file: {e}", exc_info=True)

    def set_initial_training_values(
        self, args, dataloader, total_train_batch_size
    ):
        """
        Compute the bookkeeping values for a purely step-based training run.

        ``dataloader`` is not consulted: the run length comes from
        ``args.max_steps`` alone.
        NOTE(review): presumably overrides a hook of the parent trainer —
        confirm the expected tuple layout against the parent class.

        Returns:
            A 7-tuple ``(num_train_epochs, num_update_steps_per_epoch,
            num_examples, num_train_samples, epoch_based, len_dataloader,
            max_steps)``.

        Raises:
            ValueError: If ``args.max_steps`` is not positive.
        """
        if args.max_steps <= 0:
            raise ValueError("LudicTrainer needs a positive args.max_steps")

        import sys

        step_budget = args.max_steps
        total_examples = total_train_batch_size * step_budget

        return (
            # Effectively "infinite" epochs, so the outer epoch loop cannot end
            # before `global_step` reaches `max_steps`.
            sys.maxsize,
            step_budget,      # update steps per epoch
            total_examples,   # num_examples
            total_examples,   # num_train_samples
            False,            # epoch_based: this is a "step-based" run
            None,             # len_dataloader: length is irrelevant
            step_budget,      # max_steps
        )
===== run_configs/tictactoe_l40_multi.toml =====
# RunPod Orchestrator Configuration for TicTacToe Multi-GPU
# Uses 2x A100 GPUs for high-performance training: 1 for vLLM, 1+ for training
# NOTE(review): the file name says "l40" while pod_type below requests an
# A100 — confirm which GPU type this config is meant to target.

[pod]
pod_type = "NVIDIA A100-SXM4-80GB"  # 80GB SXM
max_hours = 6
volume_size = 60
docker_image = "ghcr.io/joshuapurtell/runpod-ludic:slim-fixed"

[gpus]
per_node = 2  # Minimum required: 1 for vLLM + 1 for training
reserve0 = "vllm"  # GPU 0 for vLLM server, GPU 1 for training

[trainer]
model = "Qwen/Qwen2.5-1.5B-Instruct"
branch = "master"
train_steps = 100
env_name = "TicTacToe"
experiment_name = "tictactoe_multi_gpu_test"
batch_size = 16
gradient_accumulation_steps = 2
learning_rate = 5e-6
num_rollouts_per_device = 32
rollout_gen_n = 4
ppo_epochs = 1
output_dir = "/workspace/outputs"
prune_long_seqs = true

[logging]
wandb_project = "ludic-tictactoe"
log_frequency = 5
rsync_interval = 30

[secrets]
required = ["RUNPOD_API_KEY", "HF_TOKEN", "GH_PAT_RUNPOD"]
optional = ["WANDB_API_KEY"]

# NOTE(review): this key follows the [secrets] header, so TOML parses it as
# secrets.cluster_size rather than as a top-level key — confirm that is where
# the orchestrator reads it from.
cluster_size = 1
===== run_configs/tictactoe_small_multi.toml =====
# RunPod Orchestrator Configuration for TicTacToe Multi-GPU
# Uses 2x L40S GPUs for high-performance training: 1 for vLLM, 1+ for training

[pod]
pod_type = "NVIDIA L40S"  # 48GB GDDR6
max_hours = 6
volume_size = 50
docker_image = "ghcr.io/joshuapurtell/runpod-ludic:slim-fixed"

[gpus]
per_node = 2  # Minimum required: 1 for vLLM + 1 for training
reserve0 = "vllm"  # GPU 0 for vLLM server, GPU 1 for training

[vllm]
# vLLM server configuration optimized for 0.5B model
gpu_memory_utilization = 0.6  # Conservative for small model
max_model_len = 1024          # TicTacToe doesn't need long context
max_batch_size = 8            # Match training batch size
token_chunk_size = 16         # Smaller chunks for faster response
dtype = "bfloat16"            # Memory efficient

[trainer]
model = "Qwen/Qwen2.5-0.5B-Instruct"
branch = "master"
train_steps = 100
env_name = "TicTacToe"
experiment_name = "tictactoe_multi_gpu_test"
batch_size = 8
gradient_accumulation_steps = 8
learning_rate = 1e-6
num_rollouts_per_device = 32
rollout_gen_n = 8
ppo_epochs = 1
output_dir = "/workspace/outputs"
prune_long_seqs = true
max_prompt_length = 1024
max_completion_length = 128

# HuggingFace Hub Upload Configuration
push_to_hub           = false                    # Set to true to upload model to HuggingFace Hub
hub_model_id          = ""                       # e.g., "JoshPurtell/ludic-tictactoe-qwen-0.5b"
hub_private_repo      = true                     # Whether to create a private repository
hub_strategy          = "end"                    # When to push: "end", "every_save", "checkpoint", "all_checkpoints"

# To enable HuggingFace uploads:
# 1. Set push_to_hub = true
# 2. EITHER set hub_model_id = "JoshPurtell/your-model-name"
#    OR leave hub_model_id empty and set HF_ACCOUNT="https://huggingface.co/JoshPurtell" in your .env
#    (it will auto-generate: "JoshPurtell/ludic-{experiment_name}-{model_name}")
# 3. Ensure HF_TOKEN environment variable is set with your HuggingFace token
# 4. Get your token from: https://huggingface.co/settings/tokens

# Example to enable uploads with auto-generated model name:
# push_to_hub = true
# hub_model_id = ""  # Leave empty to auto-generate from HF_ACCOUNT

[logging]
wandb_project = "ludic-tictactoe"
log_frequency = 5
rsync_interval = 30

[secrets]
required = ["RUNPOD_API_KEY", "HF_TOKEN", "GH_PAT_RUNPOD"]
optional = ["WANDB_API_KEY"]

# NOTE(review): this key follows the [secrets] header, so TOML parses it as
# secrets.cluster_size rather than as a top-level key — confirm that is where
# the orchestrator reads it from.
cluster_size = 1