#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os

try:
    import imaspy as imas
except ImportError:
    import imas
from rich_argparse import RichHelpFormatter

from idstools.compute.common import get_nearest_time
from idstools.database import DBMaster
from idstools.utils.clihelper import (
    dbentry_parser,
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
)
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.waves import WavesView

# Script entry point: display EC results (ECRH/ECCD profiles and, when more
# than one time slice is available, the corresponding waveforms) read from the
# "waves" IDS of the data entry selected on the command line.
if __name__ == "__main__":
    # MANAGEMENT OF INPUT ARGUMENTS
    # ------------------------------
    parser = argparse.ArgumentParser(
        description="---- Display EC results [previously known as eccomp]",
        parents=[dbentry_parser, rcparam_parser],
        formatter_class=RichHelpFormatter,
    )
    parser.add_argument(
        "-t",
        "--time",
        help="Time for which profiles are displayed",
        type=float,
        default=-99.0,
    )
    # NOTE(review): --force_psi is still accepted for command-line
    # compatibility, but nothing in this script forwards it to the view.
    parser.add_argument(
        "-f",
        "--force_psi",
        help="= 1 to force displaying the profiles versus poloidal flux",
        type=int,
        default=0,
    )
    # store_true flag: it takes no value, so the help must not suggest "= 1".
    parser.add_argument(
        "--verbose",
        help="Display numerical analysis of gaussian profiles",
        action="store_true",
    )
    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    args = parser.parse_args()

    logger = setup_logger("module", stdout_level=logging.INFO)

    conn = DBMaster.get_connection(args)
    if conn is None:
        logger.critical("----> Aborted.")
        # SystemExit instead of exit(): the latter is only injected by the
        # `site` module and is not guaranteed to exist.
        raise SystemExit(1)

    # Everything that uses the connection lives inside try/finally so the
    # connection is closed on every exit path (early abort, plotting error,
    # normal completion). With lazy=True the IDS presumably reads its data on
    # demand, so the connection has to stay open until plotting is finished.
    try:
        try:
            if args.dd_update:
                ids_waves = conn.get("waves", autoconvert=False)
                # Keep the converted IDS: it is the one that must be plotted.
                ids_waves = imas.convert_ids(ids_waves, conn.factory.version)
            else:
                ids_waves = conn.get("waves", lazy=True, autoconvert=False)
        except Exception as e:
            logger.error(f"waves ids is not present, detailed error: {e}")
            raise SystemExit(1)

        # TIME VECTOR AND TIME INDEX
        # ---------------------------
        time_array = ids_waves.time
        ntime = len(time_array)
        time_slice, time_value = get_nearest_time(time_array, args.time)

        # Waveforms (second row of axes) are only drawn when there is more
        # than one time slice.
        show_waveforms = ntime != 1
        if not show_waveforms:
            logger.error("Only one time slice --> ECCD and ECRH waveforms not displayed")

        canvas = PlotCanvas(2 if show_waveforms else 1, 2)
        canvas.update_style(args.rc)
        ax1 = canvas.add_axes(title="", xlabel="", row=0, col=0)
        ax2 = canvas.add_axes(title="", xlabel="", row=0, col=1)
        if show_waveforms:
            ax3 = canvas.add_axes(title="", xlabel="", row=1, col=0)
            ax4 = canvas.add_axes(title="", xlabel="", row=1, col=1)

        waves_view = WavesView(ids_waves)

        waves_view.display_e_c_launchers_info(time_slice)
        waves_view.plot_ecrh_profiles(ax1, time_slice, verbose=args.verbose)
        waves_view.plot_eccd_profiles(ax2, time_slice, verbose=args.verbose)
        if show_waveforms:
            waves_view.plot_ecrh_waveform(ax3, time_slice)
            waves_view.plot_e_c_c_d_waveform(ax4, time_slice)

        canvas.set_text(text=f"{get_database_path(args, time_value)}")
        canvas.fig.suptitle(get_title(args, "EC Composition", time_value))
        canvas.fig.subplots_adjust(top=0.941, bottom=0.122, left=0.052, right=0.925, hspace=0.2, wspace=0.2)
        canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))

        if args.save:
            fname = get_file_name(args, os.path.basename(__file__) + "_EC_Composition", time_value)
            if args.directory:
                # exist_ok avoids the check-then-create race of a separate
                # os.path.exists() test.
                os.makedirs(args.directory, exist_ok=True)
                fname = os.path.join(args.directory, fname)
            canvas.save(fname)
        else:
            canvas.show()
    finally:
        conn.close()
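    # NOTE: everything below is disabled (commented-out) code and is never
    # executed. It sketches parsing of bracketed pulse/run lists such as
    # "[1,2,3]" and looping over the resulting pulse/run pairs.
    # def StrToList(string):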
    #     if string[0] == "[":
    #         slist = string.split(",")
    #         slist[0] = slist[0][1:]
    #         slist[-1] = slist[-1][:-1]
    #         slist = [int(slist[i]) for i in range(len(slist))]
    #         return slist
    #     else:
    #         return [int(string)]
    # pulseString, runString = args.pulse, args.run
    # pulseList, runList = StrToList(pulseString), StrToList(runString)
    # assert len(pulseList) <= len(
    #     runList
    # ), "There must be a run number corresponding to each pulse"

    # if len(runList) > len(pulseList):
    #     assert (
    #         len(pulseList) == 1
    #     ), "Either give a pulse for each run or give a single pulse for every run"
    #     pulseList = len(runList) * pulseList
    # nsimu = len(pulseList)
    # for isimu in range(len(pulseList)):
    #     ipulse = pulseList[isimu]
    #     irun = runList[isimu]
    #     # Reading in data from run in pulse
    #     logger.info(
    #         f"Open pulse: {ipulse}/{irun} in DB {args.user}/{args.database}/{args.version}"
    #     )
