"""Adversarial Attack Generator module.

This module provides tools for generating adversarial examples to test
the robustness of machine learning models, with a focus on MLX integration.
"""

import os
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from loguru import logger


class AttackType(Enum):
    """Closed set of adversarial attack algorithms selectable by name.

    Each member's value is the lowercase string identifier of the attack, so
    a member can be obtained from its identifier, e.g. ``AttackType("pgd")``.

    Attributes:
        FGSM: Fast Gradient Sign Method.
        PGD: Projected Gradient Descent.
        CW: Carlini-Wagner attack.
        BOUNDARY: Boundary attack.
        DEEPFOOL: DeepFool attack.
    """

    # Values are the identifiers accepted when an attack is named by string.
    FGSM = "fgsm"
    PGD = "pgd"
    CW = "cw"
    BOUNDARY = "boundary"
    DEEPFOOL = "deepfool"


class AdversarialAttackGenerator:
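    """Generator for adversarial examples to test model robustness.

    This class provides methods for generating adversarial examples using
    various attack methods, with MLX integration for acceleration.

    Attributes:
        epsilon: Maximum perturbation size.
        attack_type: Type of adversarial attack.
        targeted: Whether the attack is targeted.
        random_start: Whether to use random initialization for iterative attacks.
        clip_min: Minimum value that adversarial examples are clipped to.
        clip_max: Maximum value that adversarial examples are clipped to.
        max_iter: Default number of iterations for iterative attacks.
        model: Model to attack; ``None`` until ``generate`` is called.
        loss_fn: Loss function for generating adversarial examples; ``None``
            until ``generate`` is called.
    """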
    
    def __init__(
        self,
        epsilon: float = 0.1,
        attack_type: Union[str, AttackType] = AttackType.FGSM,
        targeted: bool = False,
        random_start: bool = True,
        clip_min: float = 0.0,
        clip_max: float = 1.0,
        max_iter: int = 10,
    ) -> None:
        """Configure an attack generator.

        A string ``attack_type`` is lowercased and resolved through
        ``AttackType``; a string that matches no member logs a warning and
        falls back to FGSM rather than raising. A non-string value is stored
        as given.

        Args:
            epsilon: Maximum perturbation size.
            attack_type: Attack to run, as an ``AttackType`` or its string
                identifier (case-insensitive).
            targeted: Whether the attack is targeted.
            random_start: Whether iterative attacks start from a random point.
            clip_min: Lower bound used when clipping adversarial examples.
            clip_max: Upper bound used when clipping adversarial examples.
            max_iter: Maximum number of iterations for iterative attacks.
        """
        self.epsilon = epsilon

        # Resolve the attack type once so the rest of the class can rely on
        # comparing against enum members.
        resolved_type = attack_type
        if isinstance(attack_type, str):
            try:
                resolved_type = AttackType(attack_type.lower())
            except ValueError:
                logger.warning(f"Unknown attack type: {attack_type}, defaulting to FGSM")
                resolved_type = AttackType.FGSM
        self.attack_type = resolved_type

        self.targeted, self.random_start = targeted, random_start
        self.clip_min, self.clip_max = clip_min, clip_max
        self.max_iter = max_iter

        # Placeholders; generate() assigns both before running an attack.
        self.model = None
        self.loss_fn = None

        logger.info(f"Initialized {self.attack_type.value.upper()} attack generator")
        logger.info(f"Parameters: epsilon={epsilon}, targeted={targeted}, random_start={random_start}")
    
    def generate(
        self,
        model: nn.Module,
        inputs: mx.array,
        targets: mx.array,
        loss_fn: Optional[Callable] = None,
        **kwargs,
    ) -> Tuple[mx.array, Dict[str, Any]]:
        """Generate adversarial examples with the configured attack.

        Stores ``model`` and the effective loss on the instance, then
        dispatches to the attack implementation matching ``self.attack_type``.

        Args:
            model: Model to attack.
            inputs: Original inputs.
            targets: Target labels.
                For untargeted attacks, these are the true labels.
                For targeted attacks, these are the target labels.
            loss_fn: Callable ``(logits, labels) -> loss``. The gradient-based
                attacks differentiate it with ``mx.value_and_grad`` and call
                ``.item()`` on the result, so it must return a scalar.
                If None, mean-reduced cross-entropy is used.
            **kwargs: Additional attack-specific parameters, forwarded
                unchanged to the selected attack.

        Returns:
            Tuple[mx.array, Dict[str, Any]]:
                - Adversarial examples.
                - Attack statistics.
        """
        self.model = model

        if loss_fn is None:
            # cross_entropy defaults to reduction="none" (one loss per
            # example). The attacks need a scalar objective, so reduce to the
            # batch mean; this leaves the sign of each input gradient intact.
            def _mean_cross_entropy(logits, labels):
                return nn.losses.cross_entropy(logits, labels, reduction="mean")

            self.loss_fn = _mean_cross_entropy
        else:
            self.loss_fn = loss_fn

        # Only the selected attack's method is looked up, so attacks that are
        # not chosen are never touched.
        if self.attack_type == AttackType.FGSM:
            attack = self._fgsm_attack
        elif self.attack_type == AttackType.PGD:
            attack = self._pgd_attack
        elif self.attack_type == AttackType.CW:
            attack = self._cw_attack
        elif self.attack_type == AttackType.BOUNDARY:
            attack = self._boundary_attack
        elif self.attack_type == AttackType.DEEPFOOL:
            attack = self._deepfool_attack
        else:
            logger.warning(f"Unknown attack type: {self.attack_type}, defaulting to FGSM")
            attack = self._fgsm_attack

        adv_x, stats = attack(inputs, targets, **kwargs)
        return adv_x, stats
    
    def _fgsm_attack(
        self, inputs: mx.array, targets: mx.array, **kwargs
    ) -> Tuple[mx.array, Dict[str, Any]]:
        """Run a single-step Fast Gradient Sign Method (FGSM) attack.

        Takes one step of size ``self.epsilon`` along the sign of the loss
        gradient with respect to ``inputs``, then clips the result to
        ``[self.clip_min, self.clip_max]``.

        Args:
            inputs: Original inputs.
            targets: Labels passed to the loss; also compared against the
                adversarial predictions when the attack is targeted.
            **kwargs: Accepted for a uniform attack signature; not read here.

        Returns:
            Tuple[mx.array, Dict[str, Any]]:
                - Adversarial examples.
                - Attack statistics: ``attack_type``, ``loss`` (evaluated at
                  the original inputs), ``l2_norm``, ``linf_norm`` and
                  ``success_rate``.
        """
        def objective(x):
            return self.loss_fn(self.model(x), targets)

        loss_value, input_grad = mx.value_and_grad(objective)(inputs)

        # Targeted attacks step down the loss (toward the target labels);
        # untargeted attacks step up it (away from the given labels).
        signed_eps = -self.epsilon if self.targeted else self.epsilon
        adv_x = mx.clip(
            inputs + signed_eps * mx.sign(input_grad), self.clip_min, self.clip_max
        )

        # Norms are taken per example over axes 1..ndim-1 (axis 0 is treated
        # as the batch axis) and then averaged.
        reduce_axes = tuple(range(1, inputs.ndim))
        delta = adv_x - inputs
        l2_norm = mx.sqrt(mx.sum(delta ** 2, axis=reduce_axes)).mean()
        linf_norm = mx.max(mx.abs(delta), axis=reduce_axes).mean()

        adv_preds = mx.argmax(self.model(adv_x), axis=1)
        if self.targeted:
            # Success: the prediction equals the requested target.
            hits = adv_preds == targets
        else:
            # Success: the prediction differs from the clean-input prediction.
            hits = adv_preds != mx.argmax(self.model(inputs), axis=1)
        success_rate = mx.mean(hits).item()

        return adv_x, {
            "attack_type": self.attack_type.value,
            "loss": loss_value.item(),
            "l2_norm": l2_norm.item(),
            "linf_norm": linf_norm.item(),
            "success_rate": success_rate,
        }
    
    def _pgd_attack(
        self, inputs: mx.array, targets: mx.array, **kwargs
    ) -> Tuple[mx.array, Dict[str, Any]]:
        """Run a Projected Gradient Descent (PGD) attack.

        Repeats signed-gradient steps of size ``alpha``. After every step the
        perturbation is projected back onto the L-infinity ball of radius
        ``self.epsilon`` around ``inputs`` and the result is clipped to
        ``[self.clip_min, self.clip_max]``.

        Args:
            inputs: Original inputs.
            targets: Labels passed to the loss; also compared against the
                adversarial predictions when the attack is targeted.
            **kwargs: Optional overrides:
                alpha: Step size per iteration (default ``epsilon / 5``).
                num_steps: Number of iterations (default ``self.max_iter``).

        Returns:
            Tuple[mx.array, Dict[str, Any]]:
                - Adversarial examples.
                - Attack statistics: ``attack_type``, ``loss``, ``l2_norm``,
                  ``linf_norm``, ``success_rate`` and ``num_steps``. ``loss``
                  is the value seen at the start of the final step, or the
                  loss of the returned examples when no step was taken.
        """
        alpha = kwargs.get("alpha", self.epsilon / 5)
        num_steps = kwargs.get("num_steps", self.max_iter)

        if self.random_start:
            # MLX's uniform sampler takes low/high/shape (not the JAX-style
            # minval/maxval keywords).
            noise = mx.random.uniform(
                low=-self.epsilon, high=self.epsilon, shape=inputs.shape
            )
            adv_x = mx.clip(inputs + noise, self.clip_min, self.clip_max)
        else:
            # The loop below only rebinds adv_x to new arrays and never
            # mutates it, so no defensive copy of inputs is needed.
            adv_x = inputs

        def loss_fn(x):
            return self.loss_fn(self.model(x), targets)

        # Loop-invariant: build the differentiated function once.
        loss_and_grad = mx.value_and_grad(loss_fn)

        # Targeted attacks minimise the loss; untargeted attacks maximise it.
        signed_alpha = -alpha if self.targeted else alpha

        value = None
        for _ in range(num_steps):
            value, grad = loss_and_grad(adv_x)
            adv_x = adv_x + signed_alpha * mx.sign(grad)

            # Project back into the epsilon ball, then into the valid range.
            delta = mx.clip(adv_x - inputs, -self.epsilon, self.epsilon)
            adv_x = mx.clip(inputs + delta, self.clip_min, self.clip_max)

        # Per-example norms over axes 1..ndim-1, averaged over axis 0.
        reduce_axes = tuple(range(1, inputs.ndim))
        l2_norm = mx.sqrt(mx.sum((adv_x - inputs) ** 2, axis=reduce_axes)).mean()
        linf_norm = mx.max(mx.abs(adv_x - inputs), axis=reduce_axes).mean()

        # Evaluate adversarial examples
        adv_logits = self.model(adv_x)
        adv_preds = mx.argmax(adv_logits, axis=1)

        if value is None:
            # No step ran (num_steps <= 0), so no loss was recorded; report
            # the loss of the examples being returned instead of failing.
            value = self.loss_fn(adv_logits, targets)

        if self.targeted:
            success_rate = mx.mean(adv_preds == targets).item()
        else:
            orig_logits = self.model(inputs)
            orig_preds = mx.argmax(orig_logits, axis=1)
            success_rate = mx.mean(adv_preds != orig_preds).item()

        stats = {
            "attack_type": self.attack_type.value,
            "loss": value.item(),
            "l2_norm": l2_norm.item(),
            "linf_norm": linf_norm.item(),
            "success_rate": success_rate,
            "num_steps": num_steps,
        }

        return adv_x, stats
    
    def _cw_attack(
        self, inputs: mx.array, targets: mx.array, **kwargs
    ) -> Tuple[mx.array, Dict[str, Any]]:
        """Implement Carlini-Wagner (CW) attack.
        
        Args:
            inputs: Original inputs.
            targets: Target labels.
            **kwargs: Additional parameters.
        
        Returns:
            Tuple[mx.array, Dict[str, Any]]:
                - Adversarial examples.
                - Attack statistics.
        """
        num_steps = kwargs.get("num_steps", self.max_iter)
        c = kwargs.get("c", 1.0)
        learning_rate = kwargs.get("learning_rate", 0.01)
        
        #