"""Transformer model implementation using MLX.

This module provides an implementation of the Transformer architecture
using MLX for acceleration.
"""

import math
from typing import Dict, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.transformer import TransformerDecoderLayer
from mlx.nn.layers.transformer import TransformerEncoderLayer


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for the Transformer model.

    Implements the fixed sinusoidal positional encoding described in the
    "Attention Is All You Need" paper: even feature columns hold sines and
    odd feature columns hold cosines of the scaled position. Dropout is
    applied after the encoding has been added to the input.

    Attributes:
        d_model: Dimension of the model.
        max_len: Maximum sequence length.
        dropout: Dropout layer applied to the encoded input.
        pe: Precomputed positional encoding of shape (1, max_len, d_model).
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1) -> None:
        """Initialize a new PositionalEncoding.

        Args:
            d_model: Dimension of the model.
            max_len: Maximum sequence length.
            dropout: Dropout rate.
        """
        super().__init__()

        self.d_model = d_model
        self.max_len = max_len
        self.dropout = nn.Dropout(dropout)

        # Angle for every (position, frequency) pair.
        # Shape: (max_len, ceil(d_model / 2)).
        position = mx.expand_dims(mx.arange(0, max_len, dtype=mx.float32), axis=1)
        div_term = mx.exp(
            mx.arange(0, d_model, 2, dtype=mx.float32) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        # MLX arrays have no JAX-style `.at[...].set(...)`; use slice
        # assignment to fill the interleaved sine / cosine columns.
        pe = mx.zeros((max_len, d_model))
        pe[:, 0::2] = mx.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = mx.cos(angles[:, : d_model // 2])

        self.pe = mx.expand_dims(pe, axis=0)  # Shape: (1, max_len, d_model)

        # A public mx.array attribute is registered as a parameter. The table
        # is a fixed function of position, so mark it frozen to keep it out of
        # the trainable parameters (a later unfreeze() on a parent module
        # would make it trainable again).
        self.freeze(keys="pe", recurse=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Add positional encoding to the input.

        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model).

        Returns:
            mx.array: Output tensor with positional encoding added, after
            dropout.

        Raises:
            ValueError: If the sequence length of ``x`` exceeds the number of
                precomputed positions.
        """
        seq_len = x.shape[1]
        available = self.pe.shape[1]
        if seq_len > available:
            raise ValueError(
                f"Sequence length {seq_len} exceeds maximum length {available}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)


class TransformerEncoder(nn.Module):
    """Transformer encoder implementation.

    A stack of MLX Transformer encoder layers followed by a final layer
    normalization, as in the encoder of the "Attention Is All You Need"
    paper.

    Attributes:
        layers: List of encoder layers.
        norm: Layer normalization applied to the output of the last layer.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int,
        dropout: float = 0.1,
    ) -> None:
        """Initialize a new TransformerEncoder.

        Args:
            d_model: Dimension of the model.
            nhead: Number of attention heads.
            num_layers: Number of encoder layers.
            dim_feedforward: Dimension of the feedforward network.
            dropout: Dropout rate.
        """
        super().__init__()

        # The MLX layer names its arguments (dims, num_heads, mlp_dims,
        # dropout), not d_model / n_head / dim_feedforward. Pass them
        # positionally so the call does not depend on keyword spelling.
        self.layers = [
            TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ]
        self.norm = nn.LayerNorm(d_model)

    def __call__(self, src: mx.array, src_mask: Optional[mx.array] = None) -> mx.array:
        """Forward pass through the encoder.

        Args:
            src: Source sequence.
            src_mask: Mask for source sequence, or None for no masking.

        Returns:
            mx.array: Encoded sequence.
        """
        output = src

        for layer in self.layers:
            # The MLX encoder layer is called as layer(x, mask); it has no
            # `src_mask` keyword.
            output = layer(output, src_mask)

        return self.norm(output)


class TransformerDecoder(nn.Module):
    """Transformer decoder implementation.

    A stack of MLX Transformer decoder layers followed by a final layer
    normalization, as in the decoder of the "Attention Is All You Need"
    paper.

    Attributes:
        layers: List of decoder layers.
        norm: Layer normalization applied to the output of the last layer.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int,
        dropout: float = 0.1,
    ) -> None:
        """Initialize a new TransformerDecoder.

        Args:
            d_model: Dimension of the model.
            nhead: Number of attention heads.
            num_layers: Number of decoder layers.
            dim_feedforward: Dimension of the feedforward network.
            dropout: Dropout rate.
        """
        super().__init__()

        # The MLX layer names its arguments (dims, num_heads, mlp_dims,
        # dropout), not d_model / n_head / dim_feedforward. Pass them
        # positionally so the call does not depend on keyword spelling.
        # The class is taken from its defining module because it may not be
        # re-exported in the `mlx.nn` namespace.
        self.layers = [
            TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ]
        self.norm = nn.LayerNorm(d_model)

    def __call__(
        self,
        tgt: mx.array,
        memory: mx.array,
        tgt_mask: Optional[mx.array] = None,
        memory_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """Forward pass through the decoder.

        Args:
            tgt: Target sequence.
            memory: Memory from the encoder.
            tgt_mask: Mask for target sequence, or None for no masking.
            memory_mask: Mask for memory, or None for no masking.

        Returns:
            mx.array: Decoded sequence.
        """
        output = tgt

        for layer in self.layers:
            # The MLX decoder layer is called as
            # layer(x, memory, x_mask, memory_mask); it has no `tgt_mask`
            # keyword, so the masks are passed positionally.
            output = layer(output, memory, tgt_mask, memory_mask)

        return self.norm(output)


class SimpleTransformer(nn.Module):
    """Simple Transformer implementation for sequence tasks.
    
    This implements a simple Transformer model for sequence-to-sequence tasks,
    as described in "Attention Is All You Need" paper.
    
    Attributes:
        d_model: Dimension of the model.
        src_embed: Source embedding.
        tgt_embed: Target embedding.
        pos_encoder: Positional encoding.
        encoder: Transformer encoder.
        decoder: Transformer decoder.
        output_layer: Output projection layer.
    """
    
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforwar