#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Example of call:
# ecray -p [130012,130012] -r [24,23] -u schneim -d iter

import argparse
import copy
import logging
import os

try:
    import imaspy as imas
except ImportError:
    import imas
from rich_argparse import RichHelpFormatter

from idstools.compute.common import get_nearest_time
from idstools.compute.waves import WavesCompute
from idstools.database import DBMaster
from idstools.utils.clihelper import (
    dbentry_parser,
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
)
from idstools.utils.idshelper import get_available_ids_and_times
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.domain.ecstray import EcStrayView
from idstools.view.equilibrium import EquilibriumView
from idstools.view.pf_active import PFActiveView
from idstools.view.wall import WallView
from idstools.view.waves import WavesView


class MdAction(argparse.Action):
    """Argparse action that lets an option work both as a flag and as a list.

    When the option is given without any value, ``True`` is stored on the
    namespace; otherwise the supplied values are stored unchanged.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # An empty value list means the option was used as a bare flag.
        stored = values if len(values) != 0 else True
        setattr(namespace, self.dest, stored)


def parse_md_spec(md_spec):
    """Split one ``--md`` value into ``(uri, ids_names)``.

    Recognised forms:

    * ``"wall"``           -> ``(None, ["wall"])``      (bare IDS name)
    * ``"imas:...#wall"``  -> ``("imas:...", ["wall"])`` (URI with IDS fragment)
    * ``"imas:..."``       -> ``("imas:...", [])``       (URI only)

    ``uri`` is ``None`` when the value carries no URI. An empty name list
    means no IDS was named explicitly, so the caller has to discover the
    available IDSs itself. A value containing more than one ``#`` matches
    none of the forms above and yields ``(None, [])``.
    """
    parts = md_spec.split("#")
    if len(parts) == 2:
        uri, ids_name = parts
        return uri, [ids_name]
    if len(parts) == 1:
        if "imas:" in md_spec:
            return md_spec, []
        return None, [md_spec]
    return None, []


def load_ids(connection, ids_name, dd_update=False):
    """Read the IDS called ``ids_name`` from an open ``connection``.

    With ``dd_update`` set, the IDS is read in full and passed through
    ``imas.convert_ids`` using ``connection.factory.version``; otherwise it
    is requested with ``lazy=True``. ``autoconvert`` is disabled in both
    cases. Exceptions raised by the underlying calls propagate to the caller.
    """
    if dd_update:
        ids = connection.get(ids_name, autoconvert=False)
        return imas.convert_ids(ids, connection.factory.version)
    return connection.get(ids_name, lazy=True, autoconvert=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="---- Display EC wave ray-tracing results [Previously known as ecray]",
        parents=[dbentry_parser, rcparam_parser],
        formatter_class=RichHelpFormatter,
    )
    parser.add_argument(
        "-md",
        "--md",
        nargs="*",
        action=MdAction,
        default=[],
        help="""Provide machine descriptions that you need to plot\n
    with ids names for example wall pf_active \n
    or with uris for example\n
    "imas:mdsplus?user=public;shot=116000;run=4;database=ITER_MD;version=3#wall"\n
    "imas:mdsplus?user=public;shot=116000;run=4;database=ITER_MD;version=3"\n
    """,
    )
    parser.add_argument("-t", "--time", help="Time", required=False, type=float, default=-99.0)

    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    args = parser.parse_args()

    logger = setup_logger("module", stdout_level=logging.INFO)

    # Colour / line-style palettes. A single database entry is drawn, so the
    # first element of each palette is the one in use; each palette is indexed
    # modulo its own length.
    shotcolors = ["b", "r", "c", "y", "m", "b"]
    shotstyle = ["-", "--", "-.", ":", ".", ","]
    colorcounter = 0
    color = shotcolors[colorcounter % len(shotcolors)]
    style = shotstyle[colorcounter % len(shotstyle)]

    # Figure layout: poloidal view | top view | two stacked power plots.
    canvas = PlotCanvas(2, 3)
    canvas.update_style(args.rc)
    ax_polview = canvas.add_axes(title="", xlabel="", row=0, col=0, rowspan=2)
    ax_topview = canvas.add_axes(title="", xlabel="", row=0, col=1, rowspan=2)
    ax_powview = canvas.add_axes(title="", xlabel="", row=0, col=2)
    ax_powparview = canvas.add_axes(title="", xlabel="", row=1, col=2)

    conn = DBMaster.get_connection(args)
    if conn is None:
        logger.critical("----> Aborted.")
        raise SystemExit(1)

    # All three IDSs are mandatory: abort on the first one that cannot be read.
    loaded_ids = {}
    for required_name in ("waves", "core_profiles", "equilibrium"):
        try:
            loaded_ids[required_name] = load_ids(conn, required_name, args.dd_update)
        except Exception as e:
            logger.error(f"{required_name} ids is not present : {e}")
            raise SystemExit(1) from e
    ids_waves = loaded_ids["waves"]
    ids_core_profiles = loaded_ids["core_profiles"]
    ids_equilibrium = loaded_ids["equilibrium"]

    # Search for adequate time slice for display
    time_slice, time_value = get_nearest_time(ids_waves.time, args.time)

    if len(ids_waves.code.name) > 0:
        logger.info(f"Code name = {ids_waves.code.name.upper()}")
        label_code = ids_waves.code.name.upper()
    else:
        label_code = get_title(args)

    ecstra_view = EcStrayView(
        equilibrium_ids=ids_equilibrium,
        core_profiles_ids=ids_core_profiles,
        waves_ids=ids_waves,
    )
    equi_view = EquilibriumView(ids_equilibrium)
    wave_view = WavesView(ids_waves)
    wave_compute = WavesCompute(ids_waves)

    beam_tracing_dict = wave_compute.get_beam_tracing(time_slice)
    logger.info(
        f"There are {beam_tracing_dict['active_beams_count']} active beam(s) "
        f"and each beam has {beam_tracing_dict['max_total_beams']} ray(s)"
    )

    ecstra_view.plot_poloidal_view(ax_polview, coherent_wave_index=0, time_slice=time_slice)

    # A bare "--md" (stored as True by MdAction) selects the default overlays.
    if args.md is True:
        args.md = ["wall", "pf_active"]

    # Machine-description IDSs that can be overlaid on the poloidal view,
    # drawn in this order.
    md_plotters = {
        "wall": lambda ids: WallView(ids).view_wall(ax_polview, wallcolor="slategray"),
        "pf_active": lambda ids: PFActiveView(ids).view_active_pf_coils(ax_polview),
    }

    for md_spec in args.md:
        md_uri, requested_ids = parse_md_spec(md_spec)
        if md_uri is None:
            # No URI given: read from the same entry as the main data.
            md_args = copy.deepcopy(args)
        else:
            md_args = argparse.Namespace(uri=md_uri)

        if len(requested_ids) == 0:
            # No IDS named explicitly: consider every IDS found in the entry.
            ids_connection = DBMaster.get_connection(md_args)
            if ids_connection is None:
                logger.warning(f"Cannot open the data entry for machine description '{md_spec}'")
                continue
            try:
                requested_ids = [
                    ids_name.split("/")[0] for ids_name, _ in get_available_ids_and_times(ids_connection)
                ]
            finally:
                ids_connection.close()

        for md_ids_name, plot_md in md_plotters.items():
            if md_ids_name not in requested_ids:
                continue
            connection = DBMaster.get_connection(md_args)
            if connection is None:
                logger.warning(f"{md_ids_name} ids is not present")
                continue
            try:
                md_ids = load_ids(connection, md_ids_name, args.dd_update)
                if md_ids.ids_properties.homogeneous_time != imas.ids_defs.EMPTY_INT:
                    plot_md(md_ids)
                else:
                    logger.warning(f"{md_ids_name} ids is empty in the dbentry")
            except Exception as e:
                # Best effort: a missing or broken overlay must not abort the plot.
                logger.warning(f"{md_ids_name} ids is empty in the dbentry {e}")
            finally:
                connection.close()

    wave_view.plot_pol_view_traces(
        ax_polview,
        time_slice,
        color=color,
        style=style,
    )

    equi_view.plot_topplotequilibrium(ax_topview, time_slice)
    wave_view.plot_top_view_traces(ax_topview, time_slice, color=color, style=style, label=label_code)

    wave_view.plot_electron_power(ax_powview, time_slice, color=color, style=style)
    wave_view.plot_power_flow_normal(ax_powparview, time_slice, color=color, style=style)

    canvas.set_text(text=f"{get_database_path(args, time_value=time_value)}")

    canvas.fig.subplots_adjust(top=0.95, bottom=0.097, left=0, right=0.948, hspace=0.2, wspace=0.108)
    canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))
    canvas.fig.suptitle(get_title(args, "EC rays", time_value))

    if args.save:
        fname = get_file_name(args, os.path.basename(__file__) + "_EC_rays", time_value)
        if args.directory:
            if not os.path.exists(args.directory):
                os.makedirs(args.directory)
            fname = os.path.join(args.directory, fname)
        canvas.save(fname)
    else:
        canvas.show()

    # The main IDSs may have been opened lazily, so the connection is only
    # released once all plotting and output is done.
    conn.close()

    # def StrToList(string):
    #     if string[0] == "[":
    #         slist = string.split(",")
    #         slist[0] = slist[0][1:]
    #         slist[-1] = slist[-1][:-1]
    #         slist = [int(slist[i]) for i in range(len(slist))]
    #         return slist
    #     else:
    #         return [int(string)]

    # pulseString, runString = args.pulse, args.run
    # pulseList, runList = StrToList(pulseString), StrToList(runString)
    # assert len(pulseList) <= len(
    #     runList
    # ), "There must be a run number corresponding to each pulse"
    # if len(runList) > len(pulseList):
    #     assert (
    #         len(pulseList) == 1
    #     ), "Either give a pulse for each run or give a single pulse for every run"
    #     pulse = len(runList) * pulseList

    # for i in range(len(pulseList)):
    #     ipulse = pulseList[i]
    #     irun = runList[i]
