"""Differentially Private Synthetic Data Generation module.

This module provides tools for generating synthetic data with differential privacy
guarantees using MLX-based generative models.
"""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import pandas as pd
from loguru import logger


@dataclass
class DPParameters:
    """Hyper-parameters that configure differentially private training.

    All fields have defaults, so ``DPParameters()`` yields a usable
    configuration. Field order is part of the positional constructor
    signature and must not change.

    Attributes:
        epsilon: Privacy budget (default ``1.0``).
        delta: Probability of privacy violation (default ``1e-5``).
        noise_multiplier: Scale of the noise added to gradients. ``None``
            (the default) leaves the value unset; ``DPMLX`` fills it in
            from ``epsilon``, ``delta`` and ``sample_rate`` when it is
            constructed with such an instance.
        max_grad_norm: Maximum L2 norm gradients are clipped to
            (default ``1.0``).
        sample_rate: Sampling rate used for privacy accounting
            (default ``0.01``).
    """

    # Privacy budget.
    epsilon: float = 1.0
    # Probability of privacy violation.
    delta: float = 1e-5
    # Noise scale; None means "derive it later".
    noise_multiplier: Optional[float] = None
    # L2 clipping threshold for gradients.
    max_grad_norm: float = 1.0
    # Sampling rate for privacy accounting.
    sample_rate: float = 0.01


class _TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention then feed-forward.

    Both sub-layers are wrapped in residual connections. Inputs of shape
    (batch_size, hidden_dim) are treated as length-1 sequences, because
    ``nn.MultiHeadAttention`` operates on (batch, sequence, features) arrays;
    the temporary sequence axis is removed again before returning.
    """

    def __init__(self, hidden_dim: int, num_heads: int, dropout: float) -> None:
        """Initialize the block.

        Args:
            hidden_dim: Feature dimension of the block's input and output.
            num_heads: Number of attention heads.
            dropout: Dropout rate applied after attention and inside the MLP.
        """
        super().__init__()

        self.attention_norm = nn.LayerNorm(hidden_dim)
        # MLX's MultiHeadAttention takes no dropout argument; dropout is
        # applied to its output instead.
        self.attention = nn.MultiHeadAttention(hidden_dim, num_heads=num_heads)
        self.attention_dropout = nn.Dropout(dropout)

        self.mlp_norm = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout),
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the block.

        Args:
            x: Array of shape (batch_size, hidden_dim) or
                (batch_size, seq_len, hidden_dim).

        Returns:
            mx.array: Array with the same shape as ``x``.
        """
        is_flat = x.ndim == 2
        if is_flat:
            # Add a length-1 sequence axis so attention sees (B, 1, D).
            x = mx.expand_dims(x, 1)

        normed = self.attention_norm(x)
        # Self-attention: queries, keys and values are the same tensor.
        x = x + self.attention_dropout(self.attention(normed, normed, normed))
        x = x + self.mlp(self.mlp_norm(x))

        return mx.squeeze(x, axis=1) if is_flat else x


class MLXTransformerGAN(nn.Module):
    """MLX-based Transformer GAN for synthetic data generation.

    This model uses a transformer-based architecture for both the generator
    and discriminator to create synthetic tabular data. Each network is a
    linear embedding, a stack of transformer blocks and a linear output head.

    Attributes:
        input_dim: Dimension of the input data.
        hidden_dim: Dimension of hidden layers.
        num_heads: Number of attention heads.
        num_layers: Number of transformer layers.
        dropout: Dropout rate.
        latent_dim: Dimension of the generator's latent input
            (equal to ``hidden_dim``).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        num_heads: int = 4,
        num_layers: int = 3,
        dropout: float = 0.1,
    ) -> None:
        """Initialize a new MLXTransformerGAN.

        Args:
            input_dim: Dimension of the input data.
            hidden_dim: Dimension of hidden layers.
            num_heads: Number of attention heads.
            num_layers: Number of transformer layers.
            dropout: Dropout rate.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout

        # Generator: latent vector -> hidden -> transformer stack -> data.
        self.latent_dim = hidden_dim
        self.generator_embedding = nn.Linear(self.latent_dim, hidden_dim)
        self.generator_layers = nn.Sequential(
            *[
                self._build_transformer_block(hidden_dim, num_heads, dropout)
                for _ in range(num_layers)
            ]
        )
        self.generator_output = nn.Linear(hidden_dim, input_dim)

        # Discriminator: data -> hidden -> transformer stack -> score.
        self.discriminator_embedding = nn.Linear(input_dim, hidden_dim)
        self.discriminator_layers = nn.Sequential(
            *[
                self._build_transformer_block(hidden_dim, num_heads, dropout)
                for _ in range(num_layers)
            ]
        )
        self.discriminator_output = nn.Linear(hidden_dim, 1)

    def _build_transformer_block(
        self, hidden_dim: int, num_heads: int, dropout: float
    ) -> nn.Module:
        """Build a transformer block.

        The block is a single-argument callable so it can be chained inside
        ``nn.Sequential``; a bare ``nn.MultiHeadAttention`` cannot, because
        it must be called with separate queries, keys and values.

        Args:
            hidden_dim: Dimension of hidden layers.
            num_heads: Number of attention heads.
            dropout: Dropout rate.

        Returns:
            nn.Module: Transformer block mapping an array to one of the
                same shape.
        """
        return _TransformerBlock(hidden_dim, num_heads, dropout)

    def generator(self, z: mx.array) -> mx.array:
        """Generate synthetic data from latent vectors.

        Args:
            z: Latent vectors of shape (batch_size, latent_dim).

        Returns:
            mx.array: Generated data of shape (batch_size, input_dim).
        """
        x = self.generator_embedding(z)
        x = self.generator_layers(x)
        return self.generator_output(x)

    def discriminator(self, x: mx.array) -> mx.array:
        """Discriminate between real and synthetic data.

        Args:
            x: Data of shape (batch_size, input_dim).

        Returns:
            mx.array: Discrimination scores (raw logits, no sigmoid applied)
                of shape (batch_size, 1).
        """
        h = self.discriminator_embedding(x)
        h = self.discriminator_layers(h)
        return self.discriminator_output(h)

    def sample(self, batch_size: int) -> mx.array:
        """Generate a batch of synthetic samples.

        Args:
            batch_size: Number of samples to generate.

        Returns:
            mx.array: Generated samples of shape (batch_size, input_dim).
        """
        # Latent vectors are drawn from a standard normal distribution.
        z = mx.random.normal((batch_size, self.latent_dim))
        return self.generator(z)


class MLXVAESynthesizer(nn.Module):
    """MLX-based Variational Autoencoder for synthetic data generation.

    This model uses a VAE architecture to generate synthetic tabular data.
    The encoder and decoder are mirrored MLPs built from
    Linear -> LayerNorm -> GELU -> Dropout stages.

    Attributes:
        input_dim: Dimension of the input data.
        hidden_dims: Dimensions of hidden layers.
        latent_dim: Dimension of the latent space.
        dropout: Dropout rate.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Optional[List[int]] = None,
        latent_dim: int = 64,
        dropout: float = 0.1,
    ) -> None:
        """Initialize a new MLXVAESynthesizer.

        Args:
            input_dim: Dimension of the input data.
            hidden_dims: Dimensions of hidden layers, in encoder order.
                Defaults to ``[256, 128]`` when ``None``. An empty list
                yields purely linear encoder/decoder heads.
            latent_dim: Dimension of the latent space.
            dropout: Dropout rate.
        """
        super().__init__()

        # Build a fresh list per instance: avoids sharing a mutable default
        # and decouples the model from later mutation of the caller's list.
        hidden_dims = [256, 128] if hidden_dims is None else list(hidden_dims)

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.latent_dim = latent_dim
        self.dropout = dropout

        # Encoder
        encoder_layers = []
        prev_dim = input_dim

        for dim in hidden_dims:
            encoder_layers.append(nn.Linear(prev_dim, dim))
            encoder_layers.append(nn.LayerNorm(dim))
            encoder_layers.append(nn.GELU())
            encoder_layers.append(nn.Dropout(dropout))
            prev_dim = dim

        self.encoder = nn.Sequential(*encoder_layers)
        # prev_dim is the encoder's output width (hidden_dims[-1], or
        # input_dim when there are no hidden layers).
        self.fc_mu = nn.Linear(prev_dim, latent_dim)
        self.fc_var = nn.Linear(prev_dim, latent_dim)

        # Decoder (mirror of the encoder)
        decoder_layers = []
        prev_dim = latent_dim

        for dim in reversed(hidden_dims):
            decoder_layers.append(nn.Linear(prev_dim, dim))
            decoder_layers.append(nn.LayerNorm(dim))
            decoder_layers.append(nn.GELU())
            decoder_layers.append(nn.Dropout(dropout))
            prev_dim = dim

        self.decoder = nn.Sequential(*decoder_layers)
        # prev_dim is the decoder's output width (hidden_dims[0], or
        # latent_dim when there are no hidden layers).
        self.final_layer = nn.Linear(prev_dim, input_dim)

    def encode(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """Encode input data to latent space.

        Args:
            x: Input data of shape (batch_size, input_dim).

        Returns:
            Tuple[mx.array, mx.array]: Mean and log variance of the latent
                distribution.
        """
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def reparameterize(self, mu: mx.array, log_var: mx.array) -> mx.array:
        """Sample latent vectors with the reparameterization trick.

        Computes ``mu + eps * exp(0.5 * log_var)`` with ``eps`` drawn from a
        standard normal distribution.

        Args:
            mu: Mean of the latent distribution.
            log_var: Log variance of the latent distribution.

        Returns:
            mx.array: Sampled latent vectors, same shape as ``mu``.
        """
        std = mx.exp(0.5 * log_var)
        eps = mx.random.normal(mu.shape)
        return mu + eps * std

    def decode(self, z: mx.array) -> mx.array:
        """Decode latent vectors to data space.

        Args:
            z: Latent vectors of shape (batch_size, latent_dim).

        Returns:
            mx.array: Decoded data of shape (batch_size, input_dim).
        """
        return self.final_layer(self.decoder(z))

    def forward(self, x: mx.array) -> Tuple[mx.array, mx.array, mx.array]:
        """Forward pass of the VAE.

        Args:
            x: Input data of shape (batch_size, input_dim).

        Returns:
            Tuple[mx.array, mx.array, mx.array]:
                - Reconstructed data.
                - Mean of the latent distribution.
                - Log variance of the latent distribution.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def __call__(self, x: mx.array) -> Tuple[mx.array, mx.array, mx.array]:
        """Run :meth:`forward` so the model can be invoked as ``model(x)``.

        MLX modules are invoked through ``__call__``; without this method
        ``model(x)`` raises ``TypeError`` because the base class does not
        dispatch to ``forward``.

        Args:
            x: Input data of shape (batch_size, input_dim).

        Returns:
            Tuple[mx.array, mx.array, mx.array]: Same as :meth:`forward`.
        """
        return self.forward(x)

    def sample(self, batch_size: int) -> mx.array:
        """Generate a batch of synthetic samples.

        Args:
            batch_size: Number of samples to generate.

        Returns:
            mx.array: Generated samples of shape (batch_size, input_dim).
        """
        # Decode latent vectors drawn from the standard normal prior.
        z = mx.random.normal((batch_size, self.latent_dim))
        return self.decode(z)


class DPMLX:
    """Differential Privacy wrapper for MLX models.

    This class implements differential privacy techniques for training
    MLX models, including gradient clipping and noising.

    NOTE(review): clipping is applied to the gradient of the whole batch
    loss, not to per-example gradients as DP-SGD requires, and the noise
    multiplier is a heuristic. Treat the resulting (epsilon, delta) as
    indicative only until this is verified with a proper privacy accountant.

    Attributes:
        model: MLX model to train with DP.
        optimizer: MLX optimizer.
        dp_params: Differential privacy parameters.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        dp_params: DPParameters,
    ) -> None:
        """Initialize a new DPMLX wrapper.

        Args:
            model: MLX model to train with DP.
            optimizer: MLX optimizer.
            dp_params: Differential privacy parameters. If its
                ``noise_multiplier`` is ``None`` it is computed here and
                written back onto this same object.

        Raises:
            ValueError: If the noise multiplier must be computed and
                ``epsilon`` or ``delta`` is out of range.
        """
        self.model = model
        self.optimizer = optimizer
        self.dp_params = dp_params

        # Compute noise multiplier if not provided
        if self.dp_params.noise_multiplier is None:
            self.dp_params.noise_multiplier = self._compute_noise_multiplier(
                self.dp_params.epsilon,
                self.dp_params.delta,
                self.dp_params.sample_rate,
            )

    def _compute_noise_multiplier(
        self, epsilon: float, delta: float, sample_rate: float
    ) -> float:
        """Compute the noise multiplier for a given privacy budget.

        Uses ``sqrt(2 * ln(1.25 / delta)) * sample_rate / epsilon``. This is
        a simplified approximation; in practice one would use a proper DP
        accounting library.

        Args:
            epsilon: Privacy budget parameter. Must be positive.
            delta: Probability of privacy violation. Must lie in (0, 1).
            sample_rate: Sampling rate for privacy accounting.

        Returns:
            float: Noise multiplier (a plain Python float).

        Raises:
            ValueError: If ``epsilon <= 0`` or ``delta`` is not in (0, 1);
                these would otherwise silently produce inf/nan noise.
        """
        if not epsilon > 0:
            raise ValueError(f"epsilon must be positive, got {epsilon}")
        if not 0 < delta < 1:
            raise ValueError(f"delta must be in (0, 1), got {delta}")

        c = np.sqrt(2 * np.log(1.25 / delta))
        # Convert from np.float64 so downstream MLX calls get a plain float.
        return float(c * sample_rate / epsilon)

    def clip_and_add_noise(self, grads: Dict[str, Any]) -> Dict[str, Any]:
        """Clip gradients and add noise for differential privacy.

        The global L2 norm over all gradient arrays is computed, every array
        is scaled by ``min(max_grad_norm / (norm + 1e-6), 1)``, and Gaussian
        noise with standard deviation ``noise_multiplier * max_grad_norm``
        is added to each array.

        Args:
            grads: Gradients. May be a flat dict or an arbitrarily nested
                tree of dicts/lists (the structure MLX produces for module
                gradients). ``None`` leaves are passed through unchanged.

        Returns:
            Dict[str, Any]: New tree with the same structure holding the
                clipped and noised gradients.
        """
        # Local import keeps the module's import block untouched; mlx.utils
        # ships with MLX, which this module already depends on.
        from mlx.utils import tree_flatten, tree_map

        leaves = [g for _, g in tree_flatten(grads) if g is not None]
        if not leaves:
            # Nothing to clip or noise; still return a fresh tree.
            return tree_map(lambda g: g, grads)

        # Accumulate the squared norm lazily on-device (in float32) instead
        # of forcing one host sync per parameter with ``.item()``.
        squared_norm = mx.array(0.0)
        for g in leaves:
            squared_norm = squared_norm + mx.sum(mx.square(g.astype(mx.float32)))

        max_norm = self.dp_params.max_grad_norm
        total_norm = mx.sqrt(squared_norm)
        # Only ever scale down; the epsilon guards against a zero norm.
        clip_factor = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
        noise_std = float(self.dp_params.noise_multiplier) * max_norm

        def _privatize(grad: Any) -> Any:
            if grad is None:
                return None
            noise = mx.random.normal(grad.shape, dtype=grad.dtype, scale=noise_std)
            return grad * clip_factor.astype(grad.dtype) + noise

        return tree_map(_privatize, grads)

    def compute_loss(self, loss_fn: callable, *args) -> Tuple[mx.array, Dict[str, Any]]:
        """Compute loss and DP gradients.

        Gradients are taken with respect to the trainable parameters of
        ``self.model``.

        Args:
            loss_fn: Loss function returning a scalar ``mx.array``. It is
                called as ``loss_fn(*args)``.
            *args: Arguments to pass to the loss function.

        Returns:
            Tuple[mx.array, Dict[str, Any]]: Loss value (still an
                ``mx.array``) and the clipped-and-noised gradients.
        """
        # nn.value_and_grad needs the model whose parameters are
        # differentiated as its first argument.
        loss, grads = nn.value_and_grad(self.model, loss_fn)(*args)

        # Apply DP
        dp_grads = self.clip_and_add_noise(grads)

        return loss, dp_grads

    def step(self, loss_fn: callable, *args) -> float:
        """Perform a DP training step.

        Args:
            loss_fn: Loss function.
            *args: Arguments to pass to the loss function.

        Returns:
            float: Loss value.
        """
        loss, dp_grads = self.compute_loss(loss_fn, *args)
        self.optimizer.update(self.model, dp_grads)
        # MLX is lazy: materialize the new parameters and optimizer state
        # now so the computation graph does not grow across steps.
        mx.eval(self.model.parameters(), self.optimizer.state)
        return loss.item()


class DPSyntheticDataGenerator:
    """Differentially Private Synthetic Data Generator.
    
    This class provides tools for generating synthetic data with differential
    privacy guarantees using MLX-based generative models.
    
    Attributes:
        epsilon: Privacy budget parameter.
        delta: Probability of privacy violation.
        model_type: Type of generative model to use.
        input_dim: Dimension of the input data.
        max_training_steps: Maximum number of training steps.
        batch_size: Batch size for training.
        model: Generative model.
        dp_params: Differential privacy parameters.
        data_processor: Data preprocessing utility.
    """
    
    def __init__(
        self,
        epsilon: float = 1.0,
        delta: float = 1e-5,
        model_type: str = "mlx_transformer_gan",
        input_dim: int = 768,
        max_training_steps: int = 10000,
        batch_size: int = 64,
        dp_params: Optional[DPParameters] = None,
        seed: Optional[int] = None,
        dataset_size: int = 1000,
    ) -> None:
        """Initialize a new DPSyntheticDataGenerator.

        Args:
            epsilon: Privacy budget parameter.
            delta: Probability of privacy violation.
            model_type: Type of generative model to use. The model is built
                immediately via ``_initialize_model``, which raises
                ``ValueError`` for any type it does not recognize.
            input_dim: Dimension of the input data.
            max_training_steps: Maximum number of training steps.
            batch_size: Batch size for training.
            dp_params: Differential privacy parameters.
                If None, parameters are built from ``epsilon``, ``delta``
                and a sample rate of ``batch_size / dataset_size``.
            seed: Random seed. When given, seeds both NumPy's global RNG
                and MLX's global RNG.
            dataset_size: Expected number of training records, used only to
                derive the default sample rate when ``dp_params`` is None.
                Defaults to 1000.

        Raises:
            ValueError: If ``dp_params`` is None and ``dataset_size`` is
                not positive, or if ``model_type`` is not recognized.
        """
        # Set random seed
        if seed is not None:
            np.random.seed(seed)
            mx.random.seed(seed)

        self.epsilon = epsilon
        self.delta = delta
        self.model_type = model_type
        self.input_dim = input_dim
        self.max_training_steps = max_training_steps
        self.batch_size = batch_size

        # Initialize DP parameters. An explicit None check is used so a
        # caller-supplied object is never replaced based on truthiness.
        if dp_params is None:
            if dataset_size <= 0:
                raise ValueError(
                    f"dataset_size must be positive, got {dataset_size}"
                )
            dp_params = DPParameters(
                epsilon=epsilon,
                delta=delta,
                sample_rate=batch_size / dataset_size,
            )
        self.dp_params = dp_params

        # Initialize generative model
        self.model = self._initialize_model()

        # Initialize data processor
        self.data_processor = DataProcessor()

        logger.info(f"Initialized DPSyntheticDataGenerator with {model_type} model")
        # Report the parameters actually in effect; they differ from the
        # epsilon/delta arguments when a custom dp_params was supplied.
        logger.info(
            f"DP parameters: epsilon={self.dp_params.epsilon}, "
            f"delta={self.dp_params.delta}"
        )
    
    def _initialize_model(self) -> nn.Module:
        """Construct the generative model selected by ``self.model_type``.

        ``"mlx_transformer_gan"`` builds an ``MLXTransformerGAN`` and
        ``"mlx_vae"`` builds an ``MLXVAESynthesizer``, both sized to
        ``self.input_dim`` with fixed architecture hyper-parameters.

        Returns:
            nn.Module: Newly constructed generative model.

        Raises:
            ValueError: If ``self.model_type`` is any other value.
        """
        kind = self.model_type

        if kind == "mlx_transformer_gan":
            gan_config = {
                "input_dim": self.input_dim,
                "hidden_dim": 256,
                "num_heads": 4,
                "num_layers": 3,
                "dropout": 0.1,
            }
            return MLXTransformerGAN(**gan_config)

        if kind == "mlx_vae":
            vae_config = {
                "input_dim": self.input_dim,
                "hidden_dims": [256, 128],
                "latent_dim": 64,
                "dropout": 0.1,
            }
            return MLXVAESynthesizer(**vae_config)

        # No other model type is constructed here.
        raise ValueError(f"Unknown model type: {self.model_type}")
    
    def train(self, data: Union[pd.DataFrame, np.ndarray, mx.array]) -> Dict[str, List[float]]:
        """Train the generative model with differential privacy.
        
        Args:
            data: Training data. Can be a DataFrame, numpy array, or MLX array.
        
        Returns:
            Dict[str, List[float]]: Training metrics.
        """
        # Preprocess data
        if isinstance(data, pd.DataFrame):
            processed_data = self.data_processor.preprocess_dataframe(data)
        else:
            processed_data = data if isinstance(data, mx.array) else mx.array(data)
        
        # Set up training
        metrics = {"loss": [], "step": []}
        
        if self.model_type == "mlx_transformer_gan":
            # Train GAN model
            optimizer_g = optim.Adam(learning_rate=1e-4)
            optimizer_d = optim.Adam(learning_rate=1e-4)
            
            # Wrap with DP
            dp_g = DPMLX(self.model, optimizer_g, self.dp_params)
            dp_d = DPMLX(self.model, optimizer_d, self.dp_params)
            
            # Training loop
            for step in range(self.max_training_steps):
                # Sample real data
                idx = np.random.choice(len(processed_data), self.batch_size, replace=False)
                real_data = processed_data[idx]
                
                # Train discriminator
                def d_loss_fn(model):
                    # Real data
                    real_pred = model.discriminator(real_data)
                    
                    # Fake data
                    z = mx.random.normal((self.batch_size, model.latent_dim))
                    fake_data = model.generator(z)
                    fake_pred = model.discriminator(fake_data)
                    
                    # Compute loss
                    loss_real = -mx.mean(mx.log(mx.sigmoid(real_pred) + 1e-8))
                    loss_fake = -mx.mean(mx.log(1 - mx.sigmoid(fake_pred) + 1e-8))
                    loss = loss_real + loss_fake
                    
                    return loss
                
                d_loss = dp_d.step(d_loss_fn, self.model)
                
                # Train generator
                def g_loss_fn(model):
                    z = mx.random.normal((self.batch_size, model.latent_dim))
                    fake_data = model.generator(z)
                    fake_pred = model.discriminator(fake_data)
                    
                    # Compute loss
                    loss = -mx.mean(mx.log(mx.sigmoid(fake_pred) + 1e-8))
                    
                    return loss
                
                g_loss = dp_g.step(g_loss_fn, self.model)
                
                # Track metrics
                if step % 100 == 0:
                    metrics["loss"].append((d_loss + g_loss) / 2)
                    metrics["step"].append(step)
                    logger.debug(f"Step {step}: D loss={d_loss:.4f}, G loss={g_loss:.4f}")
        
        elif self.model_type == "mlx_vae":
            # Train VAE model
            optimizer = optim.Adam(learning_rate=1e-4)
            
            # Wrap with DP
            dp_vae = DPMLX(self.model, optimizer, self.dp_params)
            
            # Training loop
            for step in range(self.max_training_steps):
                # Sample data
                idx = np.random.choice(len(processed_data), self.batch_size, replace=False)
                data_batch = processed_data[idx]
                
                # Train VAE
                def vae_loss_fn(model):
                    x_recon, mu, log_var = model(data_batch)
                    
                    # Reconstruction loss
                    recon_loss = mx.mean((x_recon - data_batch) ** 2)
                    
                    # KL divergence
                    kl_div = -0.5 * mx.mean(1 + log_var - mu ** 2 - mx.exp(log_var))
                    
                    # Total loss
                    loss = recon_loss + 0.1 * kl_div
                    
                    return loss
                
                loss = dp_vae.step(vae_loss_fn, self.model)
                
                # Track metrics
                if step % 100 == 0:
                    metrics["loss"].append(loss)
                    metrics["step"].append(step)
                    logger.debug(f"Step {step}: Loss={loss:.4f}")
        
        logger.info(f"Finished training after {self.max