#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
#   "fastmcp>=2.0.0",
#   "sounddevice",
#   "numpy",
#   "scipy",
#   "openai>=1.0.0",
#   "python-dotenv",
#   "livekit-agents>=0.11",
#   "livekit-plugins-openai",
#   "livekit-plugins-silero",
#   "pydub",
# ]
# requires-python = "~=3.11.0"
# ///
"""
Unified Voice MCP Server

A Model Context Protocol server that provides voice interaction capabilities
through multiple transport methods (local microphone or LiveKit rooms).

Features:
- Local voice recording and playback
- LiveKit room-based voice communication
- Configurable OpenAI-compatible STT/TTS services
- Automatic transport selection and fallback
- Privacy-conscious with clear user consent requirements

Environment Variables:
- STT_BASE_URL: Base URL for speech-to-text service (default: OpenAI)
- TTS_BASE_URL: Base URL for text-to-speech service (default: OpenAI)
- TTS_VOICE: Voice to use for TTS (default: nova)
- TTS_MODEL: TTS model to use (default: tts-1)
- STT_MODEL: STT model to use (default: whisper-1)
- OPENAI_API_KEY: API key for OpenAI services (required; startup fails with ValueError if unset)
- LIVEKIT_URL: LiveKit server URL (for LiveKit transport)
- LIVEKIT_API_KEY: LiveKit API key
- LIVEKIT_API_SECRET: LiveKit API secret
- VOICE_MCP_DEBUG: Enable debug mode (saves audio files, verbose logging)
"""

import asyncio
import logging
import os
import tempfile
import time
import shutil
from datetime import datetime
from typing import Optional, Literal
from pathlib import Path
from dotenv import load_dotenv

import sounddevice as sd
import numpy as np
from scipy.io.wavfile import write
from pydub import AudioSegment
from openai import AsyncOpenAI

from fastmcp import FastMCP

# Load environment variables
# Runs before every os.getenv() call below so values from ".env.local" are
# visible to the configuration that follows.
load_dotenv(dotenv_path=".env.local")

# Debug configuration
# VOICE_MCP_DEBUG enables debug mode when set to "true", "1", "yes" or "on"
# (compared case-insensitively); anything else, or unset, disables it.
DEBUG = os.getenv("VOICE_MCP_DEBUG", "").lower() in ("true", "1", "yes", "on")
# Folder under the user's home directory that receives debug audio files.
DEBUG_DIR = Path.home() / "voice-mcp_recordings"

if DEBUG:
    # Create the folder up front; exist_ok keeps repeated runs from failing.
    DEBUG_DIR.mkdir(exist_ok=True)

# Configure logging
# Debug mode lowers the root log level to DEBUG; otherwise INFO is used.
log_level = logging.DEBUG if DEBUG else logging.INFO
logging.basicConfig(
    level=log_level,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("voice-mcp")

# Suppress verbose binary data in HTTP logs
if DEBUG:
    # Keep our debug logs but reduce HTTP client verbosity
    logging.getLogger("openai._base_client").setLevel(logging.INFO)
    logging.getLogger("httpcore").setLevel(logging.INFO)
    logging.getLogger("httpx").setLevel(logging.INFO)

# Create MCP server
# The @mcp.tool() decorators further down register tools on this instance.
mcp = FastMCP("Voice MCP")

# Audio configuration
# Used for microphone capture and for the WAV file written before STT upload.
SAMPLE_RATE = 44100
CHANNELS = 1

# Service configuration
# STT and TTS can point at different OpenAI-compatible endpoints; both
# default to the public OpenAI API and share OPENAI_API_KEY.
STT_BASE_URL = os.getenv("STT_BASE_URL", "https://api.openai.com/v1")
TTS_BASE_URL = os.getenv("TTS_BASE_URL", "https://api.openai.com/v1")
TTS_VOICE = os.getenv("TTS_VOICE", "nova")
TTS_MODEL = os.getenv("TTS_MODEL", "tts-1")
STT_MODEL = os.getenv("STT_MODEL", "whisper-1")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# LiveKit configuration
# NOTE(review): the "devkey"/"secret" fallbacks look like local development
# credentials — confirm they are never relied on in a real deployment.
LIVEKIT_URL = os.getenv("LIVEKIT_URL", "ws://localhost:7880")
LIVEKIT_API_KEY = os.getenv("LIVEKIT_API_KEY", "devkey")
LIVEKIT_API_SECRET = os.getenv("LIVEKIT_API_SECRET", "secret")

# Fail at import time: the OpenAI clients created later need the key.
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY is required")


def get_debug_filename(prefix: str, extension: str) -> str:
    """Build a timestamped file name for a debug artifact.

    The result has the form ``YYYYMMDD_HHMMSS_mmm-<prefix>.<extension>``,
    where the timestamp is the current local time with millisecond precision.
    """
    stamp_with_micros = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    # %f produces six microsecond digits; drop the last three to keep ms.
    stamp = stamp_with_micros[:-3]
    return "-".join((stamp, f"{prefix}.{extension}"))


def save_debug_file(data: bytes, prefix: str, extension: str) -> Optional[str]:
    """Write ``data`` into DEBUG_DIR when debug mode is on.

    Returns the path of the written file as a string. Returns None when
    debug mode is off or when the write fails; failures are logged rather
    than raised.
    """
    if DEBUG:
        try:
            target = DEBUG_DIR / get_debug_filename(prefix, extension)
            with target.open('wb') as handle:
                handle.write(data)

            logger.debug(f"Debug file saved: {target}")
            return str(target)
        except Exception as e:
            logger.error(f"Failed to save debug file: {e}")

    return None


def get_openai_clients():
    """Create the AsyncOpenAI clients for the speech services.

    Returns a dict with the keys 'stt' and 'tts'. Both clients use
    OPENAI_API_KEY and differ only in the base URL they talk to.
    """
    endpoints = (('stt', STT_BASE_URL), ('tts', TTS_BASE_URL))
    return {
        role: AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=url)
        for role, url in endpoints
    }


# pydub is imported unconditionally at the top of the file, so reaching this
# line means that import succeeded.
# NOTE(review): nothing here probes for an actual MP3 codec (pydub presumably
# needs ffmpeg for MP3) — confirm the message is accurate on target systems.
logger.info("✓ MP3 support available (Python 3.11 + pydub)")

# Initialize clients
# Built once at import time; text_to_speech and speech_to_text look up the
# 'tts' and 'stt' entries of this dict.
openai_clients = get_openai_clients()


# Only one playback at a time. Playback now runs in a worker thread, so this
# lock keeps concurrent tool calls from talking over each other.
_PLAYBACK_LOCK = asyncio.Lock()


def _play_mp3_file(mp3_path: str) -> bool:
    """Decode the MP3 at ``mp3_path`` and play it, blocking until done.

    This is a synchronous helper meant to run in a worker thread. Playback
    uses sounddevice; in debug mode a failed playback is retried once through
    the ``paplay`` system command.

    Returns:
        True when the audio was played, False otherwise. Errors are logged
        and never raised.
    """
    audio = None
    samples = None
    try:
        # Load MP3 and convert to numpy array for playback
        logger.debug("Loading MP3 audio...")
        audio = AudioSegment.from_mp3(mp3_path)
        logger.debug(f"Audio loaded - Duration: {len(audio)}ms, Channels: {audio.channels}, Frame rate: {audio.frame_rate}")

        # Convert to numpy array
        logger.debug("Converting to numpy array...")
        samples = np.array(audio.get_array_of_samples())
        if audio.channels == 2:
            # Samples arrive interleaved; sounddevice wants (frames, channels).
            samples = samples.reshape((-1, 2))
            logger.debug("Reshaped for stereo")

        # Convert to float32 for sounddevice. Scale by the full range of the
        # decoded sample width (32768 for 16-bit) so values stay in [-1, 1).
        full_scale = float(1 << (8 * audio.sample_width - 1))
        samples = samples.astype(np.float32) / full_scale
        logger.debug(f"Audio converted to float32, shape: {samples.shape}")

        # Check audio devices
        if DEBUG:
            try:
                devices = sd.query_devices()
                default_output = sd.default.device[1]
                logger.debug(f"Default output device: {default_output} - {devices[default_output]['name'] if default_output is not None else 'None'}")
            except Exception as dev_e:
                logger.error(f"Error querying audio devices: {dev_e}")

        logger.debug(f"Playing audio with sounddevice at {audio.frame_rate}Hz...")
        sd.play(samples, audio.frame_rate)
        sd.wait()

        logger.info("✓ TTS played successfully")
        return True

    except Exception as e:
        logger.error(f"Error playing audio: {e}")
        logger.error(f"Audio format - Channels: {audio.channels if audio is not None else 'unknown'}, Frame rate: {audio.frame_rate if audio is not None else 'unknown'}")
        logger.error(f"Samples shape: {samples.shape if samples is not None else 'unknown'}")

        # Try alternative playback method in debug mode
        if DEBUG:
            try:
                logger.debug("Attempting alternative playback with system command...")
                import subprocess
                result = subprocess.run(['paplay', mp3_path], capture_output=True, timeout=10)
                if result.returncode == 0:
                    logger.info("✓ Alternative playback successful")
                    return True
                logger.error(f"Alternative playback failed: {result.stderr.decode(errors='replace')}")
            except Exception as alt_e:
                logger.error(f"Alternative playback error: {alt_e}")

        return False


async def text_to_speech(text: str) -> bool:
    """Convert text to speech and play it

    Requests MP3 audio from the configured TTS service, stores it in a
    temporary file and plays it. Decoding and playback are blocking, so they
    run in a worker thread instead of stalling the event loop.

    Args:
        text: The text to speak.

    Returns:
        True when the audio was played, False when the TTS request or the
        playback failed. Failures are logged, not raised.
    """
    logger.info(f"TTS: Converting text to speech: '{text[:100]}{'...' if len(text) > 100 else ''}'")
    if DEBUG:
        logger.debug(f"TTS full text: {text}")
        logger.debug(f"TTS config - Model: {TTS_MODEL}, Voice: {TTS_VOICE}, Base URL: {TTS_BASE_URL}")

    tmp_path = None
    try:
        # Use MP3 format for bandwidth efficiency
        audio_format = "mp3"

        logger.debug("Making TTS API request...")
        response = await openai_clients['tts'].audio.speech.create(
            model=TTS_MODEL,
            input=text,
            voice=TTS_VOICE,
            response_format=audio_format
        )
        audio_bytes = response.content

        logger.debug(f"TTS API response received, content length: {len(audio_bytes)} bytes")

        # Save debug file if enabled
        debug_path = save_debug_file(audio_bytes, "tts-output", audio_format)
        if debug_path:
            logger.info(f"TTS audio saved to: {debug_path}")

        # Write the audio and close the handle before anything reopens the
        # file by name (an open handle blocks that on some platforms).
        with tempfile.NamedTemporaryFile(suffix=f'.{audio_format}', delete=False) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(audio_bytes)

        logger.debug(f"Audio written to temp file: {tmp_path}")

        # Play audio off the event loop, one playback at a time.
        async with _PLAYBACK_LOCK:
            return await asyncio.to_thread(_play_mp3_file, tmp_path)

    except Exception as e:
        logger.error(f"TTS failed: {e}")
        logger.error(f"TTS config when error occurred - Model: {TTS_MODEL}, Voice: {TTS_VOICE}, Base URL: {TTS_BASE_URL}")
        if hasattr(e, 'response'):
            logger.error(f"HTTP status: {e.response.status_code if hasattr(e.response, 'status_code') else 'unknown'}")
            logger.error(f"Response text: {e.response.text if hasattr(e.response, 'text') else 'unknown'}")
        return False
    finally:
        # Single cleanup point: runs on success, failure and cancellation.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_e:
                logger.debug(f"Could not remove temp file {tmp_path}: {cleanup_e}")


async def speech_to_text(audio_data: np.ndarray) -> Optional[str]:
    """Convert audio to text

    Writes ``audio_data`` to a WAV file at SAMPLE_RATE, re-encodes it as a
    64k MP3 to shrink the upload, and sends it to the configured STT service.
    Both intermediate files live in a private temporary directory that is
    removed on every exit path.

    Args:
        audio_data: Recorded samples to transcribe.

    Returns:
        The stripped transcription, or None when the service returned empty
        text or the conversion/upload failed (failures are logged).

    Raises:
        Whatever ``scipy.io.wavfile.write`` raises if the WAV file cannot be
        written; that step is intentionally outside the logged error path.
    """
    logger.info(f"STT: Converting speech to text, audio data shape: {audio_data.shape}")
    if DEBUG:
        logger.debug(f"STT config - Model: {STT_MODEL}, Base URL: {STT_BASE_URL}")
        logger.debug(f"Audio stats - Min: {audio_data.min()}, Max: {audio_data.max()}, Mean: {audio_data.mean():.2f}")

    # Plain paths inside a temp directory: no handle stays open while other
    # code reopens the files by name, and cleanup cannot be skipped.
    with tempfile.TemporaryDirectory(prefix="voice-mcp-stt-") as work_dir:
        wav_path = os.path.join(work_dir, "recording.wav")
        logger.debug(f"Writing audio to WAV file: {wav_path}")
        write(wav_path, SAMPLE_RATE, audio_data)

        # Save debug file for original recording
        if DEBUG:
            try:
                with open(wav_path, 'rb') as f:
                    debug_path = save_debug_file(f.read(), "stt-input", "wav")
                if debug_path:
                    logger.info(f"Original recording saved to: {debug_path}")
            except Exception as e:
                logger.error(f"Failed to save debug WAV: {e}")

        try:
            # Convert WAV to MP3 for smaller upload
            logger.debug("Converting WAV to MP3 for upload...")
            audio = AudioSegment.from_wav(wav_path)
            logger.debug(f"Audio loaded - Duration: {len(audio)}ms, Channels: {audio.channels}, Frame rate: {audio.frame_rate}")

            upload_file = os.path.join(work_dir, "recording.mp3")
            # export() returns the output file handle still open; close it so
            # the file can be reopened for upload and removed afterwards.
            exported = audio.export(upload_file, format="mp3", bitrate="64k")
            exported.close()
            logger.debug(f"MP3 created for STT upload: {upload_file}")

            # Save debug file for upload version
            if DEBUG:
                try:
                    with open(upload_file, 'rb') as f:
                        debug_path = save_debug_file(f.read(), "stt-upload", "mp3")
                    if debug_path:
                        logger.info(f"Upload audio saved to: {debug_path}")
                except Exception as e:
                    logger.error(f"Failed to save debug MP3: {e}")

            # Get file size for logging
            file_size = os.path.getsize(upload_file)
            logger.debug(f"Uploading {file_size} bytes to STT API...")

            with open(upload_file, 'rb') as audio_file:
                transcription = await openai_clients['stt'].audio.transcriptions.create(
                    model=STT_MODEL,
                    file=audio_file,
                    response_format="text"
                )

            logger.debug(f"STT API response type: {type(transcription)}")
            # The result may be a plain string or an object carrying .text.
            text = transcription.strip() if isinstance(transcription, str) else transcription.text.strip()

            if text:
                logger.info(f"✓ STT result: '{text}'")
                return text

            logger.warning("STT returned empty text")
            return None

        except Exception as e:
            logger.error(f"STT failed: {e}")
            logger.error(f"STT config when error occurred - Model: {STT_MODEL}, Base URL: {STT_BASE_URL}")
            if hasattr(e, 'response'):
                logger.error(f"HTTP status: {e.response.status_code if hasattr(e.response, 'status_code') else 'unknown'}")
                logger.error(f"Response text: {e.response.text if hasattr(e.response, 'text') else 'unknown'}")
            return None


def record_audio(duration: float) -> np.ndarray:
    """Capture ``duration`` seconds from the default microphone.

    Records int16 samples at SAMPLE_RATE with CHANNELS channel(s) and blocks
    until the capture is finished. Returns the samples as a flat array, or
    an empty array if recording fails (the error and the available input
    devices are logged).
    """
    logger.info(f"🎤 Recording audio for {duration}s...")
    if DEBUG:
        try:
            known_devices = sd.query_devices()
            input_index = sd.default.device[0]
            input_name = 'None' if input_index is None else known_devices[input_index]['name']
            logger.debug(f"Default input device: {input_index} - {input_name}")
            logger.debug(f"Recording config - Sample rate: {SAMPLE_RATE}Hz, Channels: {CHANNELS}, dtype: int16")
        except Exception as dev_e:
            logger.error(f"Error querying audio devices: {dev_e}")

    try:
        frame_count = int(duration * SAMPLE_RATE)
        logger.debug(f"Recording {frame_count} samples...")

        captured = sd.rec(frame_count, samplerate=SAMPLE_RATE, channels=CHANNELS, dtype=np.int16)
        # rec() returns immediately; wait() blocks until the buffer is full.
        sd.wait()

        mono = captured.flatten()
        logger.info(f"✓ Recorded {len(mono)} samples")

        if DEBUG:
            logger.debug(f"Recording stats - Min: {mono.min()}, Max: {mono.max()}, Mean: {mono.mean():.2f}")
            # A very low RMS suggests the capture holds silence only.
            rms = np.sqrt(np.mean(mono.astype(float) ** 2))
            verdict = 'likely silence' if rms < 100 else 'audio detected'
            logger.debug(f"RMS level: {rms:.2f} ({verdict})")

        return mono

    except Exception as e:
        logger.error(f"Recording failed: {e}")
        logger.error(f"Audio config when error occurred - Sample rate: {SAMPLE_RATE}, Channels: {CHANNELS}")

        # List the input-capable devices to help diagnose the failure.
        try:
            all_devices = sd.query_devices()
            logger.error("Available input devices:")
            for index, info in enumerate(all_devices):
                if not info['max_input_channels'] > 0:
                    continue
                logger.error(f"  {index}: {info['name']} (inputs: {info['max_input_channels']})")
        except Exception as dev_e:
            logger.error(f"Cannot query audio devices: {dev_e}")

        return np.array([])


async def check_livekit_available() -> bool:
    """Check if LiveKit is available and has active rooms

    Lists the rooms on the configured LiveKit server and reports whether at
    least one of them has a participant.

    Returns:
        True when a room with one or more participants exists. False when no
        such room exists or when LiveKit cannot be imported or reached; the
        reason is logged at debug level.
    """
    try:
        from livekit import api

        # The room service is called over HTTP(S), so map the websocket scheme.
        api_url = LIVEKIT_URL.replace("ws://", "http://").replace("wss://", "https://")
        lk_api = api.LiveKitAPI(api_url, LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
        try:
            rooms = await lk_api.room.list_rooms(api.ListRoomsRequest())
        finally:
            # Release the client's HTTP session; without this every call
            # leaves an unclosed session behind.
            await lk_api.aclose()

        return any(r.num_participants > 0 for r in rooms.rooms)

    except Exception as e:
        logger.debug(f"LiveKit not available: {e}")
        return False


async def livekit_ask_voice_question(question: str, room_name: str = "", timeout: float = 60.0) -> str:
    """Ask voice question using LiveKit transport

    Joins a LiveKit room as "voice-mcp-bot", speaks ``question`` through an
    agent session and returns the first user turn captured after speaking.

    Args:
        question: Text the agent says after entering the room.
        room_name: Room to join. When empty, the first listed room that has
            at least one participant is used.
        timeout: Seconds to wait for a reply once the session has started.

    Returns:
        The captured reply, or a status string: "No active LiveKit rooms
        found", "No participants in LiveKit room", "No response within
        <timeout>s", or "LiveKit error: ..." when an exception occurred
        (exceptions are logged and reported in the return value).
    """
    connected_room = None
    try:
        from livekit import rtc, api
        from livekit.agents import Agent, AgentSession
        from livekit.plugins import openai as lk_openai, silero

        # Auto-discover room if needed
        if not room_name:
            api_url = LIVEKIT_URL.replace("ws://", "http://").replace("wss://", "https://")
            lk_api = api.LiveKitAPI(api_url, LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
            try:
                rooms = await lk_api.room.list_rooms(api.ListRoomsRequest())
            finally:
                # Release the API client's HTTP session once the list is in.
                await lk_api.aclose()

            for candidate in rooms.rooms:
                if candidate.num_participants > 0:
                    room_name = candidate.name
                    break

            if not room_name:
                return "No active LiveKit rooms found"

        # Setup TTS and STT for LiveKit
        tts_client = lk_openai.TTS(voice=TTS_VOICE, base_url=TTS_BASE_URL, model=TTS_MODEL)
        stt_client = lk_openai.STT(base_url=STT_BASE_URL, model=STT_MODEL)

        # Create simple agent that speaks and listens
        class VoiceAgent(Agent):
            def __init__(self):
                super().__init__(
                    instructions="Speak the message and listen for response",
                    stt=stt_client,
                    tts=tts_client,
                    llm=None
                )
                self.response = None
                self.has_spoken = False

            async def on_enter(self):
                # Short pause before speaking, then say the question once.
                await asyncio.sleep(0.5)
                if self.session:
                    await self.session.say(question, allow_interruptions=True)
                    self.has_spoken = True

            async def on_user_turn_completed(self, chat_ctx, new_message):
                # Keep only the first user turn that arrives after speaking.
                if self.has_spoken and not self.response and new_message.content:
                    self.response = new_message.content[0]

        # Connect and run
        token = api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
        token.with_identity("voice-mcp-bot").with_name("Voice MCP Bot")
        token.with_grants(api.VideoGrants(
            room_join=True, room=room_name,
            can_publish=True, can_subscribe=True,
        ))

        room = rtc.Room()
        await room.connect(LIVEKIT_URL, token.to_jwt())
        # From here on the room must be left again, whatever happens next.
        connected_room = room

        if not room.remote_participants:
            return "No participants in LiveKit room"

        agent = VoiceAgent()
        vad = silero.VAD.load()
        session = AgentSession(vad=vad)
        await session.start(room=room, agent=agent)

        # Wait for response. A monotonic clock keeps the timeout correct even
        # if the system wall clock is adjusted while waiting.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if agent.response:
                return agent.response
            await asyncio.sleep(0.1)

        return f"No response within {timeout}s"

    except Exception as e:
        logger.error(f"LiveKit error: {e}")
        return f"LiveKit error: {str(e)}"
    finally:
        # Leave the room on every path, including errors raised after connect.
        if connected_room is not None:
            try:
                await connected_room.disconnect()
            except Exception as disc_e:
                logger.debug(f"LiveKit disconnect failed: {disc_e}")


@mcp.tool()
async def ask_voice_question(
    question: str,
    transport: Literal["auto", "local", "livekit"] = "auto",
    room_name: str = "",
    duration: float = 5.0,
    timeout: float = 60.0
) -> str:
    """Ask a voice question and get a voice response
    
    PRIVACY NOTICE: This tool will access your microphone to record audio
    for speech-to-text conversion. Audio is processed using the configured
    STT service and is not permanently stored.
    
    Args:
        question: The question to ask
        transport: Transport method - "auto" (try LiveKit then local), "local" (direct mic), "livekit" (room-based)
        room_name: LiveKit room name (only for livekit transport, auto-discovered if empty)
        duration: Recording duration in seconds (only for local transport)
        timeout: Maximum wait time for response in seconds
    """
    # The docstring above is left untouched: it may be exposed to MCP clients
    # as the tool description.
    logger.info(f"Voice question: {question[:50]}... (transport: {transport})")

    # Resolve "auto" to a concrete transport: LiveKit when a populated room
    # exists, the local microphone otherwise.
    if transport == "auto":
        livekit_ready = await check_livekit_available()
        if livekit_ready:
            transport = "livekit"
            logger.info("Auto-selected LiveKit transport")
        else:
            transport = "local"
            logger.info("Auto-selected local transport")

    if transport == "livekit":
        return await livekit_ask_voice_question(question, room_name, timeout)

    if transport != "local":
        return f"Unknown transport: {transport}"

    # Local microphone flow: speak, pause briefly, record, transcribe.
    try:
        spoke = await text_to_speech(question)
        if not spoke:
            return "Error: Could not speak question"

        await asyncio.sleep(0.5)

        logger.info(f"🎤 Listening for {duration} seconds...")
        # Recording blocks, so it runs in the default executor.
        loop = asyncio.get_event_loop()
        recorded = await loop.run_in_executor(None, record_audio, duration)

        if len(recorded) == 0:
            return "Error: Could not record audio"

        transcript = await speech_to_text(recorded)
        if not transcript:
            return "No speech detected"
        return f"Voice response: {transcript}"

    except Exception as e:
        logger.error(f"Local voice error: {e}")
        return f"Error: {str(e)}"


@mcp.tool()
async def speak_text(text: str) -> str:
    """Convert text to speech and play it
    
    Args:
        text: Text to convert to speech
    """
    # Docstring kept verbatim: it may be exposed as the MCP tool description.
    logger.info(f"Speaking: {text[:50]}...")

    try:
        spoken = await text_to_speech(text)
    except Exception as e:
        logger.error(f"Speak error: {e}")
        return f"Error: {str(e)}"

    if spoken:
        return "✓ Text spoken successfully"
    return "✗ Failed to speak text"


@mcp.tool()
async def listen_for_speech(duration: float = 5.0) -> str:
    """Listen for speech and convert to text
    
    PRIVACY NOTICE: This tool will access your microphone to record audio
    for the specified duration. Audio is processed locally and converted
    to text using the configured STT service.
    
    Args:
        duration: How long to listen in seconds
    """
    # Docstring kept verbatim: it may be exposed as the MCP tool description.
    logger.info(f"Listening for {duration}s...")

    try:
        logger.info(f"🎤 Recording for {duration} seconds...")

        # Recording blocks, so it runs in the default executor.
        loop = asyncio.get_event_loop()
        recorded = await loop.run_in_executor(None, record_audio, duration)

        if len(recorded) == 0:
            return "Error: Could not record audio"

        transcript = await speech_to_text(recorded)
        if not transcript:
            return "No speech detected"
        return f"Speech detected: {transcript}"

    except Exception as e:
        logger.error(f"Listen error: {e}")
        return f"Error: {str(e)}"


@mcp.tool()
async def check_room_status() -> str:
    """Check LiveKit room status and participants"""
    # Returns one "• <room> (<n> participants)" line per room followed by an
    # indented line per participant identity, "No active LiveKit rooms" when
    # the server lists none, or an error string when the lookup fails.
    try:
        from livekit import api

        # The room service is called over HTTP(S), so map the websocket scheme.
        api_url = LIVEKIT_URL.replace("ws://", "http://").replace("wss://", "https://")
        lk_api = api.LiveKitAPI(api_url, LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
        try:
            rooms = await lk_api.room.list_rooms(api.ListRoomsRequest())
            if not rooms.rooms:
                return "No active LiveKit rooms"

            status = []
            for room in rooms.rooms:
                participants = await lk_api.room.list_participants(
                    api.ListParticipantsRequest(room=room.name)
                )

                status.append(f"• {room.name} ({room.num_participants} participants)")
                status.extend(f"  - {p.identity}" for p in participants.participants)

            return "\n".join(status)
        finally:
            # Release the client's HTTP session on every path; the original
            # never closed it.
            await lk_api.aclose()

    except Exception as e:
        logger.error(f"Room status error: {e}")
        return f"Error checking room status: {str(e)}"


@mcp.tool()
async def check_audio_devices() -> str:
    """List available audio input and output devices"""
    # Builds a text report: input devices, output devices, then the default
    # input/output device indices. Any failure is returned as an error string.
    try:
        inputs = []
        outputs = []

        for index, info in enumerate(sd.query_devices()):
            # A device can appear in both lists if it has both channel kinds.
            if info['max_input_channels'] > 0:
                inputs.append(f"  {index}: {info['name']} (inputs: {info['max_input_channels']})")
            if info['max_output_channels'] > 0:
                outputs.append(f"  {index}: {info['name']} (outputs: {info['max_output_channels']})")

        report = ["🎤 Input Devices:"]
        report += inputs or ["  None found"]
        report.append("\n🔊 Output Devices:")
        report += outputs or ["  None found"]

        default_input = sd.default.device[0]
        if default_input is None:
            default_input = "None"
        default_output = sd.default.device[1]
        if default_output is None:
            default_output = "None"

        report.append(f"\n📌 Default Input: {default_input}")
        report.append(f"📌 Default Output: {default_output}")

        return "\n".join(report)

    except Exception as e:
        return f"Error listing audio devices: {str(e)}"


if __name__ == "__main__":
    # Assemble the startup banner first, then log it line by line.
    banner = [
        "Voice MCP Server - Unified Transport",
        f"Debug mode: {'ENABLED' if DEBUG else 'DISABLED'}",
    ]
    if DEBUG:
        banner.append(f"Debug recordings will be saved to: {DEBUG_DIR}")
    banner.extend([
        f"STT: {STT_BASE_URL} (model: {STT_MODEL})",
        f"TTS: {TTS_BASE_URL} (model: {TTS_MODEL}, voice: {TTS_VOICE})",
        f"LiveKit: {LIVEKIT_URL}",
        f"Sample Rate: {SAMPLE_RATE}Hz",
    ])
    for banner_line in banner:
        logger.info(banner_line)

    # Hand control to the MCP server.
    mcp.run()
