#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import logging
import os

try:
    import imaspy as imas
except ImportError:
    import imas
import numpy as np
from matplotlib.animation import FuncAnimation
from matplotlib.widgets import Slider
from rich_argparse import RichHelpFormatter

from idstools.compute.common import get_nearest_time
from idstools.utils.clihelper import (
    rcparam_parser,
)
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.equilibrium import EquilibriumView
from idstools.view.wall import WallView

if __name__ == "__main__":
    # ---- Command-line interface -------------------------------------------------
    parser = argparse.ArgumentParser(
        description="""---- Display the plasma equilibrium from the equilibrium IDS.
        It also shows pf coils and wall position overlay if exists [Previously known as equiplot]""",
        formatter_class=RichHelpFormatter,
        parents=[rcparam_parser],
    )
    parser.add_argument(
        "-u",
        "--uri",
        nargs="*",
        type=str,
        required=True,
        help="uri separated by spaces if comparing " '(e.g. "imas:hdf5?path=./testdb" "./testpulse.nc")' "\n",
    )
    parser.add_argument(
        "--dd-update",
        action="store_true",
        help=(
            # Adjacent literals are concatenated: keep the separating ", " so the
            # rendered help text does not read "...if enabledotherwise...".
            "Convert IDS to the default version of the data dictionary if enabled, "
            "otherwise use the original IDS stored on disk."
        ),
    )
    parser.add_argument("-t", "--time", help="Time (default=middle)", type=float, default=-99.0)

    parser.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Set the logging level (default=%(default)s)",
    )

    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    args = parser.parse_args()
    # nargs="*" lets "-u" be given without any value; nothing could be loaded then,
    # so reject it here with a proper usage message instead of crashing later.
    if not args.uri:
        parser.error("at least one URI is required with -u/--uri")

    log_level = getattr(logging, args.log_level)
    logger = setup_logger("module", stdout_level=log_level)
    equilibrium1 = None
    equilibrium2 = None
    wall1 = None
    wall2 = None

    def _open_entry(uri, label):
        """Open the data entry at ``uri`` read-only, aborting the script on failure.

        A failing ``imas.DBEntry(...)`` call raises (a constructor never returns
        None), so the failure is caught here, logged as critical, and turned into
        exit status 1.
        """
        try:
            return imas.DBEntry(uri, "r")
        except Exception as exc:  # CLI boundary: report and abort
            logger.critical(f"----> Could not open {label} data entry ({uri}): {exc}. Aborted.")
            raise SystemExit(1) from exc

    def _load_ids(connection, ids_name):
        """Return IDS ``ids_name`` read from ``connection``.

        With ``--dd-update`` the IDS is read fully and converted to the data
        dictionary version of ``connection.factory``; otherwise it is read with
        ``lazy=True`` in the version stored on disk. NOTE(review): lazily read IDSs
        presumably need their entry to stay open while accessed, which is why the
        entries are never closed in this script.
        """
        if args.dd_update:
            ids = connection.get(ids_name, autoconvert=False)
            return imas.convert_ids(ids, connection.factory.version)
        return connection.get(ids_name, lazy=True, autoconvert=False)

    def _load_wall(connection, label):
        """Best-effort load of the wall IDS: log a warning and return None on failure."""
        try:
            return _load_ids(connection, "wall")
        except Exception as e:
            logger.warning(f"Could not load wall data for {label} connection: {e}")
            return None

    if not args.uri:
        logger.critical("----> No URI given. Aborted.")
        raise SystemExit(1)
    if len(args.uri) > 2:
        # Only one entry (display) or two entries (comparison) are supported.
        logger.warning(f"Expected one or two URIs, got {len(args.uri)}; only {args.uri[0]} is displayed.")

    connection1 = _open_entry(args.uri[0], "first")
    equilibrium1 = _load_ids(connection1, "equilibrium")
    time_array1 = equilibrium1.time
    time_slice1, time_value1 = get_nearest_time(time_array1, args.time)
    wall1 = _load_wall(connection1, "first")

    if len(args.uri) == 2:
        connection2 = _open_entry(args.uri[1], "second")
        equilibrium2 = _load_ids(connection2, "equilibrium")
        time_array2 = equilibrium2.time
        # Use the slice of the second entry closest in time to the first entry's slice.
        time_slice2 = np.argmin(abs(time_array2 - time_value1))
        time_value2 = time_array2[time_slice2]
        wall2 = _load_wall(connection2, "second")

    # ---- Figure: one row of four panels (ax1..ax4) plus a time slider below -------
    canvas = PlotCanvas(1, 4)
    canvas.update_style(args.rc)

    canvas.fig.subplots_adjust(top=0.90, bottom=0.12, left=0.044, right=0.946, hspace=0.216, wspace=0.240)

    ax1 = canvas.add_axes(title="", xlabel="", row=0, col=0)
    ax2 = canvas.add_axes(title="", xlabel="", row=0, col=1)
    ax3 = canvas.add_axes(title="", xlabel="", row=0, col=2)
    ax4 = canvas.add_axes(title="", xlabel="", row=0, col=3)

    equiview = EquilibriumView(equilibrium1)

    # Single-element list used as a mutable cell: the callbacks defined below live
    # at module scope and can update its content without a `global` statement.
    current_time_slice1 = [time_slice1]

    ax_slider = canvas.fig.add_axes([0.15, 0.02, 0.65, 0.03])

    # The slider works on *indices* into time_array1. Its own value label is hidden;
    # time_text shows both the index and the corresponding time instead.
    time_slider = Slider(
        ax_slider,
        "Time (s) (Press F1 for help)",
        0,
        len(time_array1) - 1,
        valinit=time_slice1,
        valstep=1,
        valfmt="%d",  # Show index number
    )

    time_slider.valtext.set_visible(False)

    time_text = canvas.fig.text(
        0.82,
        0.02,
        f"Index: {time_slice1}  |  Time: {time_array1[time_slice1]:.2f} s",
        ha="left",
        va="bottom",
        fontsize=10,
        color="red",
        transform=canvas.fig.transFigure,
    )

    # Playback flag and the lazily created FuncAnimation (kept referenced here).
    animation_state = {"is_playing": False, "animation": None}

    # Figure-level text artists, tracked so they can be removed before redrawing.
    text_artists = {"text1": None, "text2": None, "title": None, "help": None}

    help_visible = [False]

    # Build the wall views once (only for walls that were actually loaded);
    # the redraw code reuses them.
    wallview1 = WallView(wall1) if wall1 is not None else None
    wallview2 = WallView(wall2) if wall2 is not None else None

    def update_plots(time_idx1):
        """Redraw the four panels and the figure title for index ``time_idx1``.

        ``time_idx1`` indexes ``time_array1``. When a second equilibrium is loaded,
        its slice nearest in time to ``time_array1[time_idx1]`` is used for the
        comparison overlay. Wall outlines are redrawn on ``ax2`` from the cached
        ``WallView`` objects. The figure canvas is drawn synchronously at the end.
        """
        for ax in (ax1, ax2, ax3, ax4):
            ax.clear()

        time_val1 = time_array1[time_idx1]
        if equilibrium2 is not None:
            # Slice of the second entry closest in time to the first entry's slice.
            time_idx2 = np.argmin(abs(time_array2 - time_val1))
            time_val2 = time_array2[time_idx2]

            equiview.view_profile_plot(ax1, time_idx1, equilibrium2_ids=equilibrium2, time_index2=time_idx2)
            equiview.view_equilibrium_plot(ax2, time_index1=time_idx1, equilibrium2_ids=equilibrium2)
            equiview.view_current_plot(ax3, time_index1=time_idx1, equilibrium2_ids=equilibrium2)
            equiview.view_constraints(ax4, time_index1=time_idx1, equilibrium2_ids=equilibrium2)
        else:
            equiview.view_profile_plot(ax1, time_idx1)
            equiview.view_equilibrium_plot(ax2, time_index1=time_idx1)
            equiview.view_current_plot(ax3, time_index1=time_idx1)
            equiview.view_constraints(ax4, time_index1=time_idx1)

        # Redraw walls (cached WallView objects; None when no wall was loaded).
        for wallview in (wallview1, wallview2):
            if wallview is not None:
                wallview.view_inner_wall_line(ax2)

        # Remove the previous title artists so they do not pile up on the figure.
        for key in ("text1", "text2", "title"):
            if text_artists[key] is not None:
                text_artists[key].remove()
                text_artists[key] = None

        # Two-line header when comparing (time_val2 is only defined in that case).
        if equilibrium2 is not None:
            y = 0.99
            text_artists["text1"] = canvas.fig.text(
                0.5,
                y,
                args.uri[0] + f"(time {time_val1:.2f})",
                ha="center",
                va="top",
                color="tab:green",
                transform=canvas.fig.transFigure,
            )
            text_artists["text2"] = canvas.fig.text(
                0.5,
                y - 0.03,
                args.uri[1] + f"(time {time_val2:.2f})",
                ha="center",
                va="top",
                color="tab:blue",
                transform=canvas.fig.transFigure,
            )
        else:
            text_artists["title"] = canvas.fig.text(
                0.5,
                0.99,
                f"{args.uri[0]} time {time_val1:.2f}",
                ha="center",
                va="top",
                fontsize=12,
                transform=canvas.fig.transFigure,
            )

        canvas.fig.canvas.draw()
        canvas.fig.canvas.flush_events()

    def on_slider_change(val):
        """Slider observer: store the new index and refresh the index/time label.

        The plots are redrawn here only while playback is stopped; during playback
        the animation callback takes care of redrawing.
        """
        idx = int(val)
        current_time_slice1[0] = idx
        time_text.set_text(f"Index: {idx}  |  Time: {time_array1[idx]:.2f} s")
        if animation_state["is_playing"]:
            return
        update_plots(idx)

    def animate(frame):
        """FuncAnimation callback: step one index forward (wrapping) while playing.

        ``frame`` is supplied by FuncAnimation and not used; the position is taken
        from ``current_time_slice1`` instead.
        """
        if not animation_state["is_playing"]:
            return
        following = (current_time_slice1[0] + 1) % len(time_array1)
        current_time_slice1[0] = following
        # Moving the slider refreshes the label; while playing its observer does not
        # redraw, so the plots are refreshed explicitly.
        time_slider.set_val(following)
        update_plots(following)

    def toggle_play_pause():
        """Start or pause playback.

        The FuncAnimation is created on the first play and then reused. Its timer is
        stopped on pause and restarted on resume, so a paused viewer does not keep
        waking up every 50 ms just to run a no-op callback.
        """
        anim = animation_state["animation"]
        if not animation_state["is_playing"]:
            # Start playing
            animation_state["is_playing"] = True
            logger.info("Animation playing (SPACEBAR to pause)")
            if anim is None:
                animation_state["animation"] = FuncAnimation(
                    canvas.fig,
                    animate,
                    interval=50,
                    blit=False,
                    cache_frame_data=False,
                )
            elif anim.event_source is not None:
                # Resume the timer stopped by a previous pause (restarting a timer
                # that is still running is harmless).
                anim.event_source.start()
            # FuncAnimation starts its timer on the next draw event.
            canvas.fig.canvas.draw_idle()
        else:
            # Pause
            animation_state["is_playing"] = False
            if anim is not None and anim.event_source is not None:
                anim.event_source.stop()
            logger.info("Animation paused (SPACEBAR to resume)")

    def toggle_help():
        """Show the keyboard-help overlay, or remove it when it is already shown.

        The figure is redrawn synchronously after either change.
        """
        if not help_visible[0]:
            # Show help
            lines = (
                "Keyboard Controls:\n\n"
                "SPACEBAR  - Play/Pause animation\n"
                "ESC       - Stop and reset to start\n"
                "← / →     - Step backward/forward one frame\n"
                "F1        - Show/Hide this help\n\n"
                "Slider:\n"
                "Click and drag to navigate to any time"
            )
            box_style = dict(boxstyle="round,pad=1", facecolor="yellow", alpha=0.9, edgecolor="black", linewidth=2)
            text_artists["help"] = canvas.fig.text(
                0.5,
                0.5,
                lines,
                transform=canvas.fig.transFigure,
                ha="center",
                va="center",
                fontsize=12,
                bbox=box_style,
                zorder=1000,  # keep the overlay above every panel
            )
            help_visible[0] = True
        else:
            # Hide help
            overlay = text_artists["help"]
            if overlay is not None:
                overlay.remove()
                text_artists["help"] = None
            help_visible[0] = False
        canvas.fig.canvas.draw()

    def on_key_press(event):
        """Keyboard handler: SPACE play/pause, F1 help, ESC reset, LEFT/RIGHT step.

        Every branch that moves the slider stops playback first and then relies on
        ``time_slider.set_val`` notifying ``on_slider_change``, which redraws the
        plots while playback is stopped. The handler therefore never calls
        ``update_plots`` itself; doing so would redraw every panel twice.
        """
        key = event.key
        if key == " ":  # Spacebar
            toggle_play_pause()
        elif key == "f1":  # F1 - Toggle help
            toggle_help()
        elif key == "escape":  # ESC to stop and reset
            animation_state["is_playing"] = False
            current_time_slice1[0] = 0
            time_slider.set_val(0)  # slider observer redraws frame 0
            logger.info("Animation stopped and reset (SPACEBAR to play)")
        elif key in ("left", "right"):  # Step one frame, clamped to the valid range
            animation_state["is_playing"] = False
            step = -1 if key == "left" else 1
            new_idx = min(max(current_time_slice1[0] + step, 0), len(time_array1) - 1)
            current_time_slice1[0] = new_idx
            time_slider.set_val(new_idx)

    # Wire the keyboard and slider callbacks to the figure.
    canvas.fig.canvas.mpl_connect("key_press_event", on_key_press)
    time_slider.on_changed(on_slider_change)

    # First frame: same view calls as the redraw path, with the comparison keyword
    # arguments only passed when a second equilibrium is loaded.
    if equilibrium2:
        compare_kwargs = {"equilibrium2_ids": equilibrium2}
        profile_kwargs = dict(compare_kwargs, time_index2=time_slice2)
    else:
        compare_kwargs = {}
        profile_kwargs = {}
    equiview.view_profile_plot(ax1, time_slice1, **profile_kwargs)
    equiview.view_equilibrium_plot(ax2, time_index1=time_slice1, **compare_kwargs)
    equiview.view_current_plot(ax3, time_index1=time_slice1, **compare_kwargs)
    equiview.view_constraints(ax4, time_index1=time_slice1, **compare_kwargs)

    for wall_view in (wallview1, wallview2):
        if wall_view:
            wall_view.view_inner_wall_line(ax2)

    # Initial header, stored in text_artists so later redraws can remove it.
    if len(args.uri) == 2:
        header_specs = (
            ("text1", args.uri[0], time_value1, 0.99, "tab:green"),
            ("text2", args.uri[1], time_value2, 0.99 - 0.03, "tab:blue"),
        )
        for slot, uri, t_val, y_pos, colour in header_specs:
            text_artists[slot] = canvas.fig.text(
                0.5,
                y_pos,
                uri + f"(time {t_val:.2f})",
                ha="center",
                va="top",
                color=colour,
                transform=canvas.fig.transFigure,
            )
    else:
        text_artists["title"] = canvas.fig.text(
            0.5,
            0.99,
            f"{args.uri[0]} time {time_value1:.2f}",
            ha="center",
            va="top",
            fontsize=12,
            transform=canvas.fig.transFigure,
        )

    canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))

    if args.save:
        # The file name embeds the first entry's time value from get_nearest_time.
        fname = f"plotequicomp_time_{time_value1}.png"
        if args.directory:
            # exist_ok=True creates missing directories without the check-then-create
            # race of os.path.exists() followed by os.makedirs().
            os.makedirs(args.directory, exist_ok=True)
            fname = os.path.join(args.directory, fname)
        canvas.save(fname)
    else:
        canvas.show()
