"""Federated Learning Environment for llama_simulation.

This module provides a simulation environment for federated learning scenarios,
where multiple clients collaborate to train a global model while keeping their
data private.
"""

import random
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from loguru import logger

from llama_simulation.environments.base import Environment


def _new_train_stats() -> Dict[str, List[float]]:
    """Build an empty per-client training-statistics record."""
    return {
        "loss": [],
        "accuracy": [],
        "communication_cost": [],
    }


@dataclass
class FederatedClient:
    """A single participant in the federated learning simulation.

    Attributes:
        id: Unique identifier for this client.
        data: The client's local data partition.
        model: The client's local copy of the model, if one has been assigned.
        optimizer: The optimizer used for local training, if one has been assigned.
        train_stats: Per-round history of local training statistics, keyed by
            "loss", "accuracy" and "communication_cost".
    """

    id: str
    data: Dict[str, Any]
    model: Optional[nn.Module] = None
    optimizer: Optional[optim.Optimizer] = None
    # Every instance gets its own fresh history lists.
    train_stats: Dict[str, List[float]] = field(default_factory=_new_train_stats)


class FederatedLearningEnv(Environment):
    """Environment for simulating federated learning scenarios.
    
    This environment simulates a federated learning process, where multiple
    clients collaborate to train a global model while keeping their data private.
    
    Attributes:
        num_clients: Number of clients participating in federated learning.
        communication_rounds: Number of communication rounds.
        data_distribution: Type of data distribution across clients.
        model_architecture: Neural network architecture for the global model.
        global_model: Global model being trained.
        clients: List of clients participating in federated learning.
        fraction_fit: Fraction of clients to select for each round.
        aggregation_method: Method for aggregating client updates.
        current_round: Current communication round.
    """
    
    def __init__(
        self,
        num_clients: int = 10,
        communication_rounds: int = 10,
        data_distribution: str = "iid",
        model_architecture: str = "mlx_simple_cnn",
        client_optimizer: str = "sgd",
        fraction_fit: float = 1.0,
        aggregation_method: str = "fedavg",
        local_epochs: int = 1,
        batch_size: int = 32,
        metrics: Optional[List[str]] = None,
        seed: Optional[int] = None,
        id: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
        """Initialize a new FederatedLearningEnv.

        Builds the global model and the clients as part of construction.

        Args:
            num_clients: Number of clients participating in federated learning.
            communication_rounds: Number of communication rounds.
            data_distribution: Type of data distribution across clients.
                Options: "iid", "non_iid", "pathological_non_iid", "custom".
            model_architecture: Neural network architecture for the global model.
                Options: "mlx_simple_cnn", "mlx_resnet18", "mlx_transformer", "custom".
            client_optimizer: Optimizer to use for client training.
                Options: "sgd", "adam", "rmsprop".
            fraction_fit: Fraction of clients to select for each round.
            aggregation_method: Method for aggregating client updates.
                Options: "fedavg", "fedprox", "fedopt", "scaffold".
            local_epochs: Number of local training epochs per round.
            batch_size: Batch size for local training.
            metrics: List of metrics to track. Defaults to
                ["accuracy", "loss", "communication_cost"] when omitted.
            seed: Random seed.
            id: Unique identifier for this environment.
            name: Human-readable name of this environment.

        Raises:
            ValueError: If the model architecture, data distribution or client
                optimizer is not recognised by the initialisation helpers.
        """
        # A list literal as a default argument would be shared by every call;
        # build the default inside the call instead.
        if metrics is None:
            metrics = ["accuracy", "loss", "communication_cost"]

        super().__init__(
            id=id,
            name=name or f"FederatedLearning-{data_distribution}-{model_architecture}",
            state_dim=1 + num_clients * 3,  # round + client stats
            action_dim=1,  # client selection strategy
            num_agents=1,  # centralized orchestration
            continuous_actions=False,
        )

        # Set random seed
        if seed is not None:
            self.seed(seed)

        # Store parameters
        self.num_clients = num_clients
        self.communication_rounds = communication_rounds
        self.data_distribution = data_distribution
        self.model_architecture = model_architecture
        self.client_optimizer_type = client_optimizer
        self.fraction_fit = fraction_fit
        self.aggregation_method = aggregation_method
        self.local_epochs = local_epochs
        self.batch_size = batch_size

        # Initialize state
        self.current_round = 0
        self.global_model = None
        self.clients = []

        # Set up metrics tracking
        # NOTE(review): self.metrics is presumably created by Environment.__init__ — confirm.
        for metric in metrics:
            self.metrics[metric] = []

        self._initialize_global_model()
        self._initialize_clients()
        logger.info(f"Initialized FederatedLearningEnv with {num_clients} clients")
    
    def _initialize_global_model(self) -> None:
        """Build the global model for the configured architecture.

        Sets ``self.global_model`` and materialises its parameters.

        Raises:
            ValueError: If ``self.model_architecture`` is not one of
                "mlx_simple_cnn", "mlx_resnet18" or "mlx_transformer".
        """
        if self.model_architecture == "mlx_simple_cnn":
            # Simple CNN for image classification. The 128 * 4 * 4 input of the
            # first Linear layer corresponds to 32x32 images after three 2x2 pools.
            self.global_model = nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Flatten(),
                nn.Linear(128 * 4 * 4, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )
        elif self.model_architecture == "mlx_resnet18":
            # Import and configure ResNet18
            from llama_simulation.models.resnet import ResNet18
            self.global_model = ResNet18(num_classes=10)
        elif self.model_architecture == "mlx_transformer":
            # Simple transformer for sequence tasks
            from llama_simulation.models.transformer import SimpleTransformer
            self.global_model = SimpleTransformer(
                vocab_size=10000,
                d_model=128,
                nhead=4,
                num_encoder_layers=3,
                num_decoder_layers=3,
                dim_feedforward=512,
                dropout=0.1,
            )
        else:
            raise ValueError(f"Unknown model architecture: {self.model_architecture}")

        # MLX modules create their parameters at construction time and only
        # evaluate them lazily, so force evaluation directly. Pushing an
        # image-shaped dummy batch through the network is unnecessary, and it
        # cannot work for the transformer (token input) or with a
        # channels-first layout (MLX convolutions take channels-last input).
        mx.eval(self.global_model.parameters())
        logger.info(f"Initialized global model: {self.model_architecture}")
    
    def _initialize_clients(self) -> None:
        """Create the clients, each with a data partition, model copy and optimizer.

        Rebuilds ``self.clients`` from scratch: a synthetic dataset is created,
        split according to ``self.data_distribution``, and every client receives
        a clone of the global model plus its own optimizer instance.

        Raises:
            ValueError: If ``self.client_optimizer_type`` is not "sgd", "adam"
                or "rmsprop".
        """
        # Factories (not instances) so that every client gets its own optimizer.
        # The MLX class is spelled ``RMSprop``.
        optimizer_factories = {
            "sgd": lambda: optim.SGD(learning_rate=0.01),
            "adam": lambda: optim.Adam(learning_rate=0.001),
            "rmsprop": lambda: optim.RMSprop(learning_rate=0.001),
        }
        make_optimizer = optimizer_factories.get(self.client_optimizer_type)
        if make_optimizer is None:
            # Fail before generating data or cloning models.
            raise ValueError(f"Unknown optimizer: {self.client_optimizer_type}")

        # Create synthetic dataset and distribute it among clients
        dataset = self._create_dataset()
        client_data = self._distribute_data(dataset)

        # Create client instances
        self.clients = []
        for i in range(self.num_clients):
            client = FederatedClient(
                id=f"client_{i}",
                data=client_data[i],
                model=self._clone_model(self.global_model),
                optimizer=make_optimizer(),
            )
            self.clients.append(client)

        logger.info(f"Initialized {len(self.clients)} clients with {self.data_distribution} data distribution")
    
    def _create_dataset(self) -> Dict[str, Any]:
        """Create a synthetic dataset for federated learning simulation.

        The dataset holds 1000 samples per client: random 32x32 three-channel
        "images" with uniformly random labels drawn from 10 classes.

        Returns:
            Dict[str, Any]: Dictionary with keys "inputs", "labels",
            "num_classes" and "num_samples".
        """
        num_samples = 1000 * self.num_clients
        num_classes = 10

        # MLX convolution and pooling layers expect channels-last input
        # (batch, height, width, channels), so generate the images that way.
        inputs = mx.random.normal((num_samples, 32, 32, 3))
        labels = mx.random.randint(0, num_classes, (num_samples,))

        return {
            "inputs": inputs,
            "labels": labels,
            "num_classes": num_classes,
            "num_samples": num_samples,
        }
    
    def _distribute_data(self, dataset: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Distribute data among clients according to specified distribution.
        
        Args:
            dataset: Complete dataset to distribute.
        
        Returns:
            List[Dict[str, Any]]: List of client datasets.
        """
        inputs = dataset["inputs"]
        labels = dataset["labels"]
        num_samples = dataset["num_samples"]
        num_classes = dataset["num_classes"]
        
        client_data = []
        
        if self.data_distribution == "iid":
            # IID: randomly distribute data among clients
            indices = np.random.permutation(num_samples)
            samples_per_client = num_samples // self.num_clients
            
            for i in range(self.num_clients):
                start_idx = i * samples_per_client
                end_idx = (i + 1) * samples_per_client if i < self.num_clients - 1 else num_samples
                client_indices = indices[start_idx:end_idx]
                
                client_data.append({
                    "inputs": inputs[client_indices],
                    "labels": labels[client_indices],
                    "num_samples": len(client_indices),
                })
        
        elif self.data_distribution == "non_iid":
            # Non-IID: each client gets data from a Dirichlet distribution over classes
            alpha = 0.5  # Concentration parameter (lower = more non-IID)
            client_proportions = np.random.dirichlet(alpha=alpha * np.ones(self.num_clients), size=num_classes)
            
            # Get indices for each class
            class_indices = [np.where(labels == c)[0] for c in range(num_classes)]
            
            # Distribute data
            client_indices = [[] for _ in range(self.num_clients)]
            
            for c, indices in enumerate(class_indices):
                # Shuffle indices for this class
                np.random.shuffle(indices)
                
                # Split according to proportions
                proportions = client_proportions[c]
                proportions = proportions / proportions.sum()  # Normalize
                cumsum = np.cumsum(proportions)
                
                # Distribute indices
                start_idx = 0
                for i in range(self.num_clients):
                    end_idx = int(cumsum[i] * len(indices))
                    client_indices[i].extend(indices[start_idx:end_idx])
                    start_idx = end_idx
            
            # Create client datasets
            for i in range(self.num_clients):
                client_indices[i] = np.array(client_indices[i])
                client_data.append({
                    "inputs": inputs[client_indices[i]],
                    "labels": labels[client_indices[i]],
                    "num_samples": len(client_indices[i]),
                })
        
        elif self.data_distribution == "pathological_non_iid":
            # Pathological non-IID: each client gets data from only certain classes
            shards_per_client = 2
            num_shards = self.num_clients * shards_per_client
            
            # Sort data by label
            sorted_indices = np.argsort(labels)
            sorted_inputs = inputs[sorted_indices]
            sorted_labels = labels[sorted_indices]
            
            # Divide data into shards
            samples_per_shard = num_samples // num_shards
            shard_indices = [
                sorted_indices[i * samples_per_shard:(i + 1) * samples_per_shard]
                for i in range(num_shards)
            ]
            
            # Randomly assign shards to clients
            random.shuffle(shard_indices)
            client_shard_indices = [
                shard_indices[i * shards_per_client:(i + 1) * shards_per_client]
                for i in range(self.num_clients)
            ]
            
            # Create client datasets
            for i in range(self.num_clients):
                # Combine shards
                indices = np.concatenate(client_shard_indices[i])
                
                client_data.append({
                    "inputs": inputs[indices],
                    "labels": labels[indices],
                    "num_samples": len(indices),
                })
        
        else:
            raise ValueError(f"Unknown data distribution: {self.data_distribution}")
        
        return client_data
    
    def _clone_model(self, model: nn.Module) -> nn.Module:
        """Create a copy of a model with the same architecture and parameters.

        A fresh module is built for ``self.model_architecture`` (the layer
        definitions mirror ``_initialize_global_model``) and then loaded with
        the parameters of ``model``.

        Args:
            model: Model to clone. Must have the architecture named by
                ``self.model_architecture``.

        Returns:
            nn.Module: Cloned model.

        Raises:
            ValueError: If ``self.model_architecture`` is not recognised.
        """
        # Create a new instance with the same structure
        if self.model_architecture == "mlx_simple_cnn":
            cloned_model = nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Flatten(),
                nn.Linear(128 * 4 * 4, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )
        elif self.model_architecture == "mlx_resnet18":
            from llama_simulation.models.resnet import ResNet18
            cloned_model = ResNet18(num_classes=10)
        elif self.model_architecture == "mlx_transformer":
            from llama_simulation.models.transformer import SimpleTransformer
            cloned_model = SimpleTransformer(
                vocab_size=10000,
                d_model=128,
                nhead=4,
                num_encoder_layers=3,
                num_decoder_layers=3,
                dim_feedforward=512,
                dropout=0.1,
            )
        else:
            raise ValueError(f"Unknown model architecture: {self.model_architecture}")

        # Load the source parameters through Module.update, which walks the
        # nested parameter tree. Assigning the top-level entries with setattr
        # would replace the clone's sub-modules with plain parameter
        # containers and leave the clone unusable.
        # NOTE(review): the clone references the same parameter arrays until
        # training rebinds them — confirm nothing mutates arrays in place.
        cloned_model.update(model.parameters())
        mx.eval(cloned_model.parameters())

        return cloned_model
    
    def configure_federation(
        self,
        aggregation_method: Optional[str] = None,
        client_optimizer: Optional[str] = None,
        local_epochs: Optional[int] = None,
    ) -> None:
        """Configure federated learning parameters.

        Arguments left as ``None`` keep their current value. The optimizer name
        is validated before anything is modified, so a rejected call leaves the
        environment unchanged.

        Args:
            aggregation_method: Method for aggregating client updates.
            client_optimizer: Optimizer to use for client training
                ("sgd", "adam" or "rmsprop"). Every client receives a new
                optimizer instance.
            local_epochs: Number of local training epochs per round.

        Raises:
            ValueError: If ``client_optimizer`` is not a known optimizer name.
        """
        # Factories (not instances) so that every client gets its own optimizer.
        # The MLX class is spelled ``RMSprop``.
        optimizer_factories = {
            "sgd": lambda: optim.SGD(learning_rate=0.01),
            "adam": lambda: optim.Adam(learning_rate=0.001),
            "rmsprop": lambda: optim.RMSprop(learning_rate=0.001),
        }

        # Validate first: no attribute may change if the request is invalid.
        if client_optimizer is not None and client_optimizer not in optimizer_factories:
            raise ValueError(f"Unknown optimizer: {client_optimizer}")

        if aggregation_method is not None:
            self.aggregation_method = aggregation_method

        if client_optimizer is not None:
            self.client_optimizer_type = client_optimizer

            # Update client optimizers
            make_optimizer = optimizer_factories[client_optimizer]
            for client in self.clients:
                client.optimizer = make_optimizer()

        if local_epochs is not None:
            self.local_epochs = local_epochs

        logger.info(f"Configured federated learning: {self.aggregation_method}, {self.client_optimizer_type}, {self.local_epochs} epochs")
    
    def reset(self) -> Dict[str, Any]:
        """Restore the environment to the start of an episode.

        The round counter goes back to zero, a new global model is built, and
        every client receives a clone of that model together with an empty
        training history.

        Returns:
            Dict[str, Any]: Initial observation.
        """
        super().reset()

        self.current_round = 0
        self._initialize_global_model()

        stat_names = ("loss", "accuracy", "communication_cost")
        for participant in self.clients:
            participant.model = self._clone_model(self.global_model)
            participant.train_stats = {stat: [] for stat in stat_names}

        return self._get_observation()
    
    def step(self, actions: Union[int, List[int]]) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """Run one communication round of federated learning.

        A round selects clients, trains them locally, aggregates their models
        into the global model and evaluates the result.

        Args:
            actions: Client selection strategy, either as an int or as a list
                whose first element is used.
                - 0: Random selection
                - 1: Round-robin selection
                - 2: Importance-based selection

        Returns:
            Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
                - observation: Next observation.
                - reward: Global accuracy of this round, minus the previous
                  round's accuracy once more than one value has been recorded.
                - done: Whether the configured number of rounds is reached.
                - info: Round metrics, ids of the selected clients and the
                  round counter.
        """
        super().step(actions)

        strategy = actions[0] if isinstance(actions, list) else actions

        # Select -> train -> aggregate -> evaluate
        participants = self._select_clients(strategy)
        updates = self._train_clients(participants)
        self._aggregate_updates(updates)
        round_metrics = self._evaluate_global_model()

        # Record this round in the metrics history
        for metric_name, metric_value in round_metrics.items():
            self.add_metric(metric_name, metric_value)

        self.current_round += 1
        finished = self.current_round >= self.communication_rounds

        # Reward is the accuracy gain over the previous round when available
        reward = round_metrics["accuracy"]
        accuracy_history = self.metrics["accuracy"]
        if len(accuracy_history) > 1:
            reward -= accuracy_history[-2]

        selected_ids = [participant.id for participant in participants]
        info = {
            "metrics": round_metrics,
            "selected_clients": selected_ids,
            "round": self.current_round,
        }

        return self._get_observation(), reward, finished, info
    
    def _select_clients(self, strategy: int) -> List[FederatedClient]:
        """Select clients for the current round based on strategy.

        The number of selected clients is ``fraction_fit * num_clients``
        rounded down, but at least one and never more than ``num_clients``.

        Args:
            strategy: Client selection strategy.
                - 0: Random selection
                - 1: Round-robin selection
                - 2: Importance-based selection (by data size in round 0, by
                  most recent local loss afterwards)
                Any other value logs a warning and uses random selection.

        Returns:
            List[FederatedClient]: Selected clients.
        """
        # Cap at num_clients: sampling without replacement cannot return more
        # clients than exist (e.g. when fraction_fit > 1).
        num_to_select = min(
            self.num_clients,
            max(1, int(self.fraction_fit * self.num_clients)),
        )

        if strategy == 1:
            # Round-robin selection: a window that advances every round
            start_idx = (self.current_round * num_to_select) % self.num_clients
            indices = [(start_idx + i) % self.num_clients for i in range(num_to_select)]
        else:
            probs = None  # uniform sampling unless importance weights apply

            if strategy == 2:
                # Importance-based selection (prioritize clients with more data or higher loss)
                if self.current_round == 0:
                    # First round, select based on data size
                    importance = [client.data["num_samples"] for client in self.clients]
                else:
                    # Later rounds, select based on loss
                    importance = [
                        client.train_stats["loss"][-1] if client.train_stats["loss"] else 0
                        for client in self.clients
                    ]

                weights = np.asarray(importance, dtype=float)
                total = weights.sum() if weights.size else 0.0
                # Weighted sampling without replacement needs at least
                # num_to_select clients with non-zero weight, and the weights
                # must be finite and non-negative. Otherwise fall back to
                # uniform sampling instead of letting np.random.choice raise.
                weights_usable = (
                    weights.size > 0
                    and bool(np.all(np.isfinite(weights)))
                    and bool(np.all(weights >= 0))
                    and total > 0
                    and int(np.count_nonzero(weights)) >= num_to_select
                )
                if weights_usable:
                    probs = weights / total
            elif strategy != 0:
                # Default to random selection
                logger.warning(f"Unknown client selection strategy: {strategy}, using random selection")

            indices = np.random.choice(self.num_clients, num_to_select, replace=False, p=probs)

        selected_clients = [self.clients[i] for i in indices]

        logger.debug(f"Selected {len(selected_clients)} clients for round {self.current_round}")
        return selected_clients
    
    def _train_clients(self, clients: List[FederatedClient]) -> List[Dict[str, Any]]:
        """Train selected clients on their local data.

        Every client runs ``self.local_epochs`` epochs of mini-batch training
        with its own optimizer. All samples are used in each epoch; the last
        batch may be smaller than ``self.batch_size``. The reported loss and
        accuracy are those of the final epoch, and are 0.0 when no batch was
        processed (no local data or ``local_epochs == 0``).

        Args:
            clients: Clients to train.

        Returns:
            List[Dict[str, Any]]: One update per client with keys "client_id",
            "model" (the trained client model) and "metrics" ("loss",
            "accuracy", "communication_cost").
        """

        def batch_loss(model, batch_inputs, batch_labels):
            # Scalar loss first (that is what gets differentiated); the logits
            # ride along so accuracy needs no second forward pass.
            logits = model(batch_inputs)
            loss = mx.mean(nn.losses.cross_entropy(logits, batch_labels))
            return loss, logits

        client_updates = []

        for client in clients:
            inputs = client.data["inputs"]
            labels = client.data["labels"]
            num_samples = len(inputs)

            # nn.value_and_grad(model, fn) differentiates fn with respect to
            # the model's trainable parameters and returns (value, grads).
            loss_and_grad = nn.value_and_grad(client.model, batch_loss)

            # Defined up front so the statistics below are valid even when the
            # epoch loop does not run.
            total_loss = 0.0
            correct = 0
            num_batches = 0
            samples_seen = 0

            for _ in range(self.local_epochs):
                # Statistics describe the most recent epoch only
                total_loss = 0.0
                correct = 0
                num_batches = 0
                samples_seen = 0

                # Shuffle data
                order = np.random.permutation(num_samples)

                for start_idx in range(0, num_samples, self.batch_size):
                    # Index with an MLX array rather than a NumPy array
                    batch_indices = mx.array(order[start_idx:start_idx + self.batch_size])
                    batch_inputs = inputs[batch_indices]
                    batch_labels = labels[batch_indices]

                    (loss, logits), grads = loss_and_grad(client.model, batch_inputs, batch_labels)

                    # Update parameters and force evaluation of the lazy graph
                    client.optimizer.update(client.model, grads)
                    mx.eval(client.model.parameters(), client.optimizer.state)

                    # Track metrics
                    preds = mx.argmax(logits, axis=1)
                    total_loss += loss.item()
                    correct += mx.sum(preds == batch_labels).item()
                    num_batches += 1
                    samples_seen += len(batch_labels)

            # Calculate training stats (guarded against empty training runs)
            avg_loss = total_loss / num_batches if num_batches else 0.0
            accuracy = correct / samples_seen if samples_seen else 0.0

            # Communication cost is taken to be the size of the trained model
            communication_cost = self._calculate_model_size(client.model)

            # Update client stats
            client.train_stats["loss"].append(avg_loss)
            client.train_stats["accuracy"].append(accuracy)
            client.train_stats["communication_cost"].append(communication_cost)

            # Collect client update
            client_updates.append({
                "client_id": client.id,
                "model": client.model,
                "metrics": {
                    "loss": avg_loss,
                    "accuracy": accuracy,
                    "communication_cost": communication_cost,
                },
            })

        logger.debug(f"Trained {len(clients)} clients for round {self.current_round}")
        return client_updates
    
    def _aggregate_updates(self, client_updates: List[Dict[str, Any]]) -> None:
        """Aggregate client updates to update the global model.

        The new global parameters are the sample-weighted average of the
        participating clients' parameters. For "fedprox" every client parameter
        is first blended with the current global value (0.9 client / 0.1
        global, a simplification of the proximal term). Any other method than
        "fedavg" or "fedprox" logs a warning and is treated as FedAvg.
        Afterwards every client receives a clone of the new global model.

        Args:
            client_updates: Updates from each client, each with the keys
                "client_id" and "model". If empty, a warning is logged and
                nothing changes.

        Raises:
            KeyError: If an update names a client id that is not in
                ``self.clients``.
        """
        from mlx.utils import tree_flatten, tree_unflatten

        if not client_updates:
            logger.warning("No client updates to aggregate")
            return

        use_proximal = self.aggregation_method == "fedprox"
        if self.aggregation_method not in ("fedavg", "fedprox"):
            # Default to FedAvg
            logger.warning(f"Unknown aggregation method: {self.aggregation_method}, using FedAvg")

        # Weight each update by its client's share of the samples held by the
        # *participating* clients. Normalising by the samples of all clients
        # would make the weights sum to less than one whenever only a fraction
        # of clients trains, shrinking the averaged parameters.
        samples_by_id = {client.id: client.data["num_samples"] for client in self.clients}
        sample_counts = [samples_by_id[update["client_id"]] for update in client_updates]
        total_samples = sum(sample_counts)
        if total_samples > 0:
            weights = [count / total_samples for count in sample_counts]
        else:
            # No participant reports any samples: fall back to a plain mean.
            weights = [1.0 / len(client_updates)] * len(client_updates)

        # Module parameters form a nested tree; flatten to {dotted name: array}
        # so that they can be accumulated entry by entry.
        global_params = dict(tree_flatten(self.global_model.parameters()))
        new_params = {key: mx.zeros_like(value) for key, value in global_params.items()}

        # Weighted sum of client parameters
        for update, weight in zip(client_updates, weights):
            for key, value in tree_flatten(update["model"].parameters()):
                if use_proximal:
                    # Add proximal term (simplification)
                    value = 0.9 * value + 0.1 * global_params[key]
                new_params[key] = new_params[key] + value * weight

        # Update global model parameters
        self.global_model.update(tree_unflatten(list(new_params.items())))
        mx.eval(self.global_model.parameters())

        # Update client models with the new global model
        for client in self.clients:
            client.model = self._clone_model(self.global_model)

        logger.debug(f"Aggregated updates using {self.aggregation_method} for round {self.current_round}")
    
    def _evaluate_global_model(self) -> Dict[str, float]:
        """Evaluate the global model on all client data.

        Loss and accuracy are sample-weighted across all clients. The
        communication cost is the sum of each client's most recent recorded
        cost divided by ``self.num_clients``. Every metric is 0.0 when there
        is nothing to average over.

        Returns:
            Dict[str, float]: Metrics with keys "loss", "accuracy" and
            "communication_cost".
        """
        total_samples = 0
        total_loss = 0.0
        total_correct = 0

        for client in self.clients:
            inputs = client.data["inputs"]
            labels = client.data["labels"]
            num_client_samples = len(inputs)
            if num_client_samples == 0:
                # Nothing to evaluate for a client without data
                continue

            # Forward pass
            logits = self.global_model(inputs)
            # cross_entropy returns per-sample losses unless a reduction is
            # requested; reduce to the mean before converting to a float.
            loss = mx.mean(nn.losses.cross_entropy(logits, labels)).item()
            preds = mx.argmax(logits, axis=1)
            correct = mx.sum(preds == labels).item()

            # Accumulate metrics
            total_samples += num_client_samples
            total_loss += loss * num_client_samples
            total_correct += correct

        # Calculate metrics (guarded against an empty evaluation set)
        avg_loss = total_loss / total_samples if total_samples else 0.0
        accuracy = total_correct / total_samples if total_samples else 0.0

        # Calculate communication cost
        latest_costs = [
            client.train_stats["communication_cost"][-1]
            for client in self.clients
            if client.train_stats["communication_cost"]
        ]
        communication_cost = sum(latest_costs) / self.num_clients if self.num_clients else 0.0

        metrics = {
            "loss": avg_loss,
            "accuracy": accuracy,
            "communication_cost": communication_cost,
        }

        logger.debug(f"Global model evaluation: loss={avg_loss:.4f}, accuracy={accuracy:.4f}")
        return metrics
    
    def _get_observation(self) -> Dict[str, Any]:
        """Get the current observation of the environment.
        
        Returns:
            Dict[str, Any]: Current observation.
        """
        # Observation includes round number and client stats
        observation = {
            "round": self.current_round,
            "clients": [
                