#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os

try:
    import imaspy as imas
except ImportError:
    import imas
from rich_argparse import RichHelpFormatter

from idstools.compute.common import get_nearest_time
from idstools.compute.waves import WavesCompute
from idstools.database import DBMaster
from idstools.utils.clihelper import (
    dbentry_parser,
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
)
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.waves import WavesView

logger = setup_logger("module", stdout_level=logging.INFO)


def _load_waves_ids(conn, dd_update):
    """Fetch the ``waves`` IDS from ``conn``, or return ``None`` if that fails.

    Args:
        conn: open database entry returned by ``DBMaster.get_connection``.
        dd_update: when true, load the IDS fully and convert it with
            ``imas.convert_ids`` to ``conn.factory.version``; otherwise fetch it
            lazily without conversion.

    Returns:
        The ``waves`` IDS, or ``None`` when reading (or converting) it raised.
    """
    try:
        if dd_update:
            ids_waves = conn.get("waves", autoconvert=False)
            return imas.convert_ids(ids_waves, conn.factory.version)
        return conn.get("waves", lazy=True, autoconvert=False)
    except Exception as e:
        # Best effort: report and let the caller skip plotting. A failed conversion
        # no longer falls through with the unconverted IDS.
        logger.error(f"waves ids is not present, detailed error: {e}")
        return None


def _remove_legend(ax):
    """Remove the legend of ``ax``, if the view method created one."""
    # Axes.get_legend() returns None when the axes has no legend.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()


def _save_or_show(canvas, args, time_value):
    """Save the figure when ``args.save`` is set (under ``args.directory`` if given), else show it."""
    if not args.save:
        canvas.show()
        return
    fname = get_file_name(args, os.path.basename(__file__) + "_heating_profiles_time", time_value)
    if args.directory:
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(args.directory, exist_ok=True)
        fname = os.path.join(args.directory, fname)
    canvas.save(fname)


def show_plots(args):
    """Plot EC results read from the ``waves`` IDS of a database entry.

    Draws the absorbed power density and ``c_d`` profiles at the time slice nearest
    to ``args.time``. When the IDS holds more than one time slice, it also draws the
    EC power and ``c_d`` waveforms. The figure is then saved or shown.

    Args:
        args: parsed command-line namespace. It is passed to
            ``DBMaster.get_connection`` and must also provide ``dd_update``,
            ``time``, ``force_psi``, ``hide_legend``, ``rc``, ``save`` and
            ``directory``.

    Raises:
        SystemExit: with status 1 when no connection can be opened, the radial
            grid information is missing, or no launcher is active.
    """
    # Use the namespace we were given. Re-parsing here relied on a module-level
    # ``parser`` that only exists when the file is executed as a script.
    conn = DBMaster.get_connection(args)
    if conn is None:
        logger.critical("----> Aborted.")
        raise SystemExit(1)

    try:
        ids_waves = _load_waves_ids(conn, args.dd_update)
        if not ids_waves:
            return

        waves_compute = WavesCompute(ids_waves)
        time_array = ids_waves.time
        ntime = len(time_array)
        time_slice, time_value = get_nearest_time(time_array, args.time)

        radial_grid = waves_compute.get_radial_grid_info(time_slice, args.force_psi)
        if radial_grid is None:
            logger.critical("Radial grid information is empty ----> Abort.")
            raise SystemExit(1)
        # NOTE(review): `is True` rejects truthy non-bool flags (e.g. numpy.bool_);
        # confirm the type get_radial_grid_info stores under "is_active".
        if not any(info["is_active"] is True for info in radial_grid.values()):
            logger.critical("The waves IDS appears empty ----> Abort.")
            raise SystemExit(1)

        # The second row (waveforms vs. time) is only drawn with more than one time slice.
        show_waveforms = ntime != 1
        if not show_waveforms:
            logger.error("Only one time slice --> ECCD and ECRH waveforms not displayed")
        canvas = PlotCanvas(2 if show_waveforms else 1, 2)
        canvas.update_style(args.rc)
        ax1 = canvas.add_axes(title="", xlabel="", row=0, col=0)
        ax2 = canvas.add_axes(title="", xlabel="", row=0, col=1)
        if show_waveforms:
            ax3 = canvas.add_axes(title="", xlabel="", row=1, col=0)
            ax4 = canvas.add_axes(title="", xlabel="", row=1, col=1)

        # Legend policy: only the c_d profile panel (ax2) keeps its legend, and only
        # when --hide_legend is not given. Every other panel always drops its legend.
        waves_view = WavesView(ids_waves)
        waves_view.view_absorbed_power_density_profile(ax1, time_slice, args.force_psi)
        _remove_legend(ax1)
        waves_view.view_c_d_profile(ax2, time_slice, args.force_psi)
        if args.hide_legend:
            _remove_legend(ax2)
        if show_waveforms:
            waves_view.view_e_c_power_waveform(ax3, time_slice, args.force_psi)
            _remove_legend(ax3)
            waves_view.view_c_d_waveform(ax4, time_slice, args.force_psi)
            _remove_legend(ax4)

        canvas.set_text(text=f"{get_database_path(args, time_value=time_value)}")
        canvas.fig.subplots_adjust(
            top=0.92,
            bottom=0.122,
            left=0.05,
            right=0.896,
            hspace=0.2,
            wspace=0.13,
        )
        canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))
        canvas.fig.suptitle(get_title(args, "HCD Waves Plot", time_value))

        _save_or_show(canvas, args, time_value)
    finally:
        # Runs on every path, including the SystemExit aborts and plotting errors.
        conn.close()


if __name__ == "__main__":
    # Command-line entry point. Database-entry and rc-parameter options come from the
    # shared parent parsers; the script-specific options are added below in help order.
    parser = argparse.ArgumentParser(
        formatter_class=RichHelpFormatter,
        parents=[dbentry_parser, rcparam_parser],
        description="---- Display EC results",
    )
    parser.add_argument("-t", "--time", type=float, default=-99.0, required=False, help="Time")

    # Boolean switches, registered in the order they appear in --help.
    for option_flags, option_help in (
        (("-f", "--force_psi"), "force displaying the profiles versus poloidal flux"),
        (("-l", "--hide_legend"), "remove the legend from graphs"),
        (("--save",), "Save figure at default location"),
    ):
        parser.add_argument(*option_flags, action="store_true", help=option_help)

    parser.add_argument("--directory", default=None, help="Directory to save the figure")

    args = parser.parse_args()
    show_plots(args)


# def showPlots(args):
#     pulse = args.pulse
#     run = args.run

#     # To handle multiple datafiles (for scans)
#     if "-" in run:
#         [runmin, runmax] = [int(x) for x in run.split("-")]  # int(run.split('-'))
#     else:
#         runmin = int(run)
#         runmax = int(run)

#     for irun in range(runmin, runmax + 1):
#         logger.info(
#             f"Open pulse: {pulse}/{irun} in DB {args.user}/{args.database}/{args.version}"
#         )
