#!/usr/bin/env python3
import argparse
import logging
import os
from collections import OrderedDict

try:
    import imaspy as imas
except ImportError:
    import imas
from rich_argparse import RichHelpFormatter

from idstools.compute.common import get_nearest_time
from idstools.compute.waves import WavesCompute
from idstools.database import DBMaster
from idstools.utils.clihelper import (
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
)
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.distributions import DistributionsView
from idstools.view.waves import WavesView

logger = setup_logger("module", stdout_level=logging.INFO)


def _remove_legend(ax):
    """Remove the legend attached to ``ax``, if any.

    ``Axes.get_legend()`` returns ``None`` when nothing created a legend, so
    this guards against that instead of raising ``AttributeError``.
    """
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()


def _show_waves_plots(connargs, args, hold=False, dd_update=False, rc=""):
    """Plot EC heating/current-drive results from the ``waves`` IDS of one data entry.

    Draws the absorbed power density and current drive profiles at the time
    closest to ``args["time"]`` and, when the IDS has more than one time
    slice, the EC power and current drive waveforms. The figure is then saved
    (``args["save"]``) or shown.

    Args:
        connargs: Connection arguments for ``DBMaster.get_connection``; must
            provide ``uri`` (used in the figure title).
        args: Mapping with at least ``time``, ``force_psi``, ``hide_legend``
            and ``save`` keys.
        hold: Unused; kept for interface compatibility.
        dd_update: If True, load the IDS eagerly and convert it to the data
            dictionary version of the connection's factory; otherwise load it
            lazily without conversion.
        rc: Style parameters forwarded to ``PlotCanvas.update_style``.
    """
    conn = DBMaster.get_connection(connargs)
    if conn is None:
        logger.critical(f"data entry not found : {connargs}")
        return

    # try/finally so the data entry is closed on every exit path, including
    # the early "Abort" returns and exceptions raised while plotting.
    try:
        ids_waves = None
        try:
            if dd_update:
                ids_waves = imas.convert_ids(conn.get("waves", autoconvert=False), conn.factory.version)
            else:
                ids_waves = conn.get("waves", lazy=True, autoconvert=False)
        except Exception as e:
            # Best effort: a missing IDS is reported, not fatal.
            logger.error(f"waves ids is not present, detailed error {e}")

        if not ids_waves:
            return

        waves_compute = WavesCompute(ids_waves)

        time_array = ids_waves.time
        ntime = len(time_array)
        time_slice, time_value = get_nearest_time(time_array, args["time"])

        radial_grid = waves_compute.get_radial_grid_info(time_slice, args["force_psi"])
        if radial_grid is None:
            logger.critical("Radial grid information is empty ----> Abort.")
            return
        if not any(value["is_active"] is True for value in radial_grid.values()):
            logger.critical("The waves IDS appears empty ----> Abort.")
            return

        # Waveforms need more than one time slice; otherwise drop the second row.
        has_waveforms = ntime != 1
        if not has_waveforms:
            logger.error("Only one time slice --> ECCD and ECRH waveforms not displayed")
        canvas = PlotCanvas(2 if has_waveforms else 1, 2)
        canvas.update_style(rc)
        canvas.set_sup_title(f"HCD Waves Plot {connargs.uri} Time : {time_value:.3f}")

        ax1 = canvas.add_axes(title="", xlabel="", row=0, col=0)
        ax2 = canvas.add_axes(title="", xlabel="", row=0, col=1)
        if has_waveforms:
            ax3 = canvas.add_axes(title="", xlabel="", row=1, col=0)
            ax4 = canvas.add_axes(title="", xlabel="", row=1, col=1)

        waves_view = WavesView(ids_waves)
        # Only ax2's legend honours hide_legend; the other legends are always
        # removed (the original if/else branches were identical).
        # NOTE(review): presumably intentional to keep a single legend — confirm.
        waves_view.view_absorbed_power_density_profile(ax1, time_slice, args["force_psi"])
        _remove_legend(ax1)
        waves_view.view_c_d_profile(ax2, time_slice, args["force_psi"])
        if args["hide_legend"]:
            _remove_legend(ax2)
        if has_waveforms:
            waves_view.view_e_c_power_waveform(ax3, time_slice, args["force_psi"])
            _remove_legend(ax3)
            waves_view.view_c_d_waveform(ax4, time_slice, args["force_psi"])
            _remove_legend(ax4)

        canvas.set_text(text=f"{get_database_path(connargs, time_value=time_value)}")

        canvas.fig.subplots_adjust(
            top=0.92,
            bottom=0.122,
            left=0.05,
            right=0.896,
            hspace=0.2,
            wspace=0.13,
        )
        canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))
        canvas.fig.suptitle(get_title(connargs, "HCD Waves Plot", time_value))
        if args["save"]:
            canvas.save(get_file_name(connargs, "hcd_waves_plot", time_value))
        else:
            canvas.show()
    finally:
        conn.close()


def show_waves_plots(args, hold=False, dd_update=False, rc=""):
    """Plot the ``waves`` IDS for each run of a legacy pulse/run specification.

    ``args["run"]`` is a single run number or an inclusive ``"min-max"``
    range (used for scans); ``_show_waves_plots`` is called once per run.

    Args:
        args: Mapping with ``pulse``, ``run``, ``backend``, ``user``,
            ``database`` and ``version`` keys, plus the plotting options read
            by ``_show_waves_plots`` (``time``, ``force_psi``,
            ``hide_legend``, ``save``).
        hold: Forwarded to ``_show_waves_plots``.
        dd_update: Forwarded to ``_show_waves_plots``.
        rc: Forwarded to ``_show_waves_plots``.

    Raises:
        ValueError: If ``run`` is neither an integer nor a ``"min-max"`` range.
    """
    pulse = args["pulse"]
    # str() so an int run number is accepted as well as "N" / "N-M" strings.
    run = str(args["run"])

    # To handle multiple datafiles (for scans)
    if "-" in run:
        runmin, runmax = (int(x) for x in run.split("-"))
    else:
        runmin = runmax = int(run)
    if runmin > runmax:
        # range() would silently yield nothing; tell the user instead.
        logger.warning(f"Empty run range {run} --> nothing to display")
        return

    for irun in range(runmin, runmax + 1):
        logger.info(f"Open pulse: {pulse}/{irun} in DB {args['user']}/{args['database']}/{args['version']}")

        connargs = argparse.Namespace(
            backend=args["backend"],
            pulse=int(pulse),
            run=int(irun),
            user=args["user"],
            database=args["database"],
            version=args["version"],
        )
        connargs.uri = (
            f"imas:{args['backend'].lower()}?user={connargs.user};pulse={connargs.pulse};"
            f"run={connargs.run};database={connargs.database};version={connargs.version}"
        )
        _show_waves_plots(connargs, args, hold, dd_update, rc)


def _show_distribution_plots(connargs, args, hold=False, dd_update=False, rc=""):
    """Plot heating/current-drive results from the ``distributions`` IDS of one data entry.

    At the time closest to ``args["time"]`` this calls the ``DistributionsView``
    profile plots (individual and total absorbed power density, current drive
    profile). When the IDS has more than one time slice it also calls the
    waveform plots. The figure is then saved (``args["save"]``) or shown.

    Args:
        connargs: Connection arguments for ``DBMaster.get_connection``; must
            provide ``uri`` (used in the figure title).
        args: Mapping with at least ``time`` and ``save`` keys.
        hold: Unused; kept for interface compatibility.
        dd_update: If True, load the IDS eagerly and convert it to the data
            dictionary version of the connection's factory; otherwise load it
            lazily without conversion.
        rc: Style parameters forwarded to ``PlotCanvas.update_style``.
    """
    conn = DBMaster.get_connection(connargs)
    if conn is None:
        logger.critical(f"data entry not found : {connargs}")
        return

    # try/finally so the data entry is closed even if plotting raises.
    try:
        ids_distributions = None
        try:
            if dd_update:
                ids_distributions = imas.convert_ids(conn.get("distributions", autoconvert=False), conn.factory.version)
            else:
                ids_distributions = conn.get("distributions", lazy=True, autoconvert=False)
        except Exception as e:
            # Best effort: a missing IDS is reported, not fatal.
            logger.error(f"distributions ids is not present, detailed error {e}")

        if not ids_distributions:
            return

        time_array = ids_distributions.time
        ntime = len(time_array)
        time_slice, time_value = get_nearest_time(time_array, args["time"])

        canvas = PlotCanvas(3, 2)
        canvas.update_style(rc)
        canvas.set_sup_title(f"HCD Distributions Plot {connargs.uri} Time : {time_value:.3f}")
        if ntime == 1:
            logger.info("Only one time slice --> Power and CD waveforms not displayed")

        # Left column: profiles at the selected time slice.
        ax1 = canvas.add_axes(title="", xlabel="", row=0, col=0)
        ax2 = canvas.add_axes(title="", xlabel="", row=1, col=0)
        ax3 = canvas.add_axes(title="", xlabel="", row=2, col=0)
        distributions_view = DistributionsView(ids_distributions)
        distributions_view.plot_absorbed_power_density_individual(ax1, time_slice)
        distributions_view.plot_absorbed_power_density(ax2, time_slice)
        distributions_view.plot_cd_profile(ax3, time_slice)

        # Right column: time traces, only meaningful with several time slices.
        if ntime != 1:
            ax4 = canvas.add_axes(title="", xlabel="", row=0, col=1)
            ax5 = canvas.add_axes(title="", xlabel="", row=1, col=1)
            distributions_view.plot_nbi_fus_power_and_cd_waveforms(ax4, time_slice)
            distributions_view.plot_cd_waveform(ax5, time_slice)

        canvas.set_text(text=f"{get_database_path(connargs, time_value=time_value)}")

        canvas.fig.subplots_adjust(
            top=0.92,
            bottom=0.122,
            left=0.044,
            right=0.886,
            hspace=0.438,
            wspace=0.328,
        )
        canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))
        canvas.fig.suptitle(get_title(connargs, "HCD Distributions Plot", time_value))
        if args["save"]:
            fname = get_file_name(connargs, os.path.basename(__file__) + "_Distributions_profile_time", time_value)
            canvas.save(fname)
        else:
            canvas.show()
    finally:
        conn.close()


def show_distribution_plots(args, hold=False, dd_update=False, rc=""):
    """Plot the ``distributions`` IDS for each run of a legacy pulse/run specification.

    ``args["run"]`` is a single run number or an inclusive ``"min-max"``
    range (used for scans); ``_show_distribution_plots`` is called once per
    run.

    Args:
        args: Mapping with ``pulse``, ``run``, ``backend``, ``user``,
            ``database`` and ``version`` keys, plus the plotting options read
            by ``_show_distribution_plots`` (``time``, ``save``).
        hold: Forwarded to ``_show_distribution_plots``.
        dd_update: Forwarded to ``_show_distribution_plots``.
        rc: Forwarded to ``_show_distribution_plots``.

    Raises:
        ValueError: If ``run`` is neither an integer nor a ``"min-max"`` range.
    """
    pulse = args["pulse"]
    # str() so an int run number is accepted as well as "N" / "N-M" strings.
    run = str(args["run"])

    # To handle multiple datafiles (for scans)
    if "-" in run:
        runmin, runmax = (int(x) for x in run.split("-"))
    else:
        runmin = runmax = int(run)
    if runmin > runmax:
        # range() would silently yield nothing; tell the user instead.
        logger.warning(f"Empty run range {run} --> nothing to display")
        return

    for irun in range(runmin, runmax + 1):
        logger.info(f"Open pulse: {pulse}/{irun} in DB {args['user']}/{args['database']}/{args['version']}")
        connargs = argparse.Namespace(
            backend=args["backend"],
            pulse=int(pulse),
            run=int(irun),
            user=args["user"],
            database=args["database"],
            version=args["version"],
        )
        connargs.uri = (
            f"imas:{args['backend'].lower()}?user={connargs.user};pulse={connargs.pulse};"
            f"run={connargs.run};database={connargs.database};version={connargs.version}"
        )
        _show_distribution_plots(connargs, args, hold, dd_update, rc)


def _parse_legacy_source(spec, default_time=-99.0):
    """Parse a legacy ``pulse/run/user/backend/database/version[/time]`` spec.

    Args:
        spec: Command-line value, e.g. ``"134713/200/schneim/MDSPLUS/TORBEAM/3"``.
            ``run`` may be a single run or a ``"min-max"`` range; it is kept
            as a string.
        default_time: Time used when ``spec`` has no trailing time field.

    Returns:
        dict with ``pulse`` (int), ``run``, ``user``, ``backend``,
        ``database``, ``version`` (str) and ``time`` (float) keys.

    Raises:
        ValueError: If ``spec`` does not have 6 or 7 ``/``-separated fields,
            or if the pulse or time field is not numeric.
    """
    fields = spec.split("/")
    if len(fields) not in (6, 7):
        raise ValueError(f"expected 6 or 7 '/'-separated fields, got {len(fields)}")
    pulse, run, user, backend, database, version = fields[:6]
    return {
        "pulse": int(pulse),
        "run": run,
        "user": user,
        "backend": backend,
        "database": database,
        "version": version,
        # Convert so get_nearest_time receives a number, not a string.
        "time": float(fields[6]) if len(fields) == 7 else default_time,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="--- Display H&CD results from multiple data sets \n"
        + "- Each H&CD source is optional \n"
        + "- Time is also optional \n"
        + "- Example: \n"
        + "  plothcd -ech 134713/200/schneim/MDSPLUS/TORBEAM/3 -nbi 130012/122/schneim/MDSPLUS/nf_fopla_synergy/3 \n"
        + "  plothcd -ech imas:mdsplus?user=public;pulse=134173;run=101;database=TEST;version=3 \n",
        formatter_class=RichHelpFormatter,
        parents=[rcparam_parser],
    )
    source_descriptions = {
        "ech": "ECH results (waves)",
        "icrh": "ICRH results (distributions)",
        "nbi": "NBI results (distributions)",
        "fus": "fusion products (distributions)",
    }
    for source_name, description in source_descriptions.items():
        parser.add_argument(
            f"-{source_name}",
            help="uri(e.g. imas:hdf5?path=./testdb or ./testpulse.nc) or "
            f"pulse/run/user/backend/database/version/[time] for {description}",
            required=False,
            type=str,
        )
    parser.add_argument(
        "--dd-update",
        action="store_true",
        help=(
            "Convert IDS to the default version of the data dictionary if enabled; "
            "otherwise, use the original IDS stored on disk."
        ),
    )
    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    # NOTE(review): --directory is parsed but never forwarded to the plotting
    # functions — confirm whether get_file_name/canvas.save should use it.
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    parser.add_argument("-t", "--time", help="Time", required=False, type=float, default=-99.0)
    args = vars(parser.parse_args())

    sources = tuple(source_descriptions)
    if all(args[source] is None for source in sources):
        logger.warning("No H&CD source to display")
        logger.warning("--> Exit.")
        raise SystemExit(0)

    # A value containing "imas:" is a URI; anything else is a legacy spec.
    uri_sources = [source for source in sources if args[source] is not None and "imas:" in args[source]]
    legacy_sources = [source for source in sources if args[source] is not None and "imas:" not in args[source]]

    # Hold flags: keys are registered in this order, then every flag except
    # the last registered one is forced to False.
    holder = OrderedDict()
    for source in sources:
        if args[source] is None:
            holder[f"{source}_uri_hold"] = False
        elif source in uri_sources:
            holder[f"{source}_uri_hold"] = True
    for source in sources:
        if args[source] is None:
            holder[f"{source}_hold"] = False
        elif source in legacy_sources:
            holder[f"{source}_hold"] = True
    last_item = next(reversed(holder))
    for key in holder:
        if key != last_item:
            holder[key] = False

    # Validate every legacy spec before plotting anything, so a typo fails fast.
    hcd_sources = {}
    for source in legacy_sources:
        try:
            hcd_sources[source] = _parse_legacy_source(args[source], default_time=args["time"])
        except ValueError as err:
            logger.critical("------------------------------------------------------------")
            logger.critical("Bad input format: " + args[source] + " not valid" + f" ({err})")
            logger.critical("Arguments should be formatted like pulse/run/user/backend/database/version/[time]")
            logger.critical("--> Exit.")
            logger.critical("-------------------------------------------------------------")
            raise SystemExit(1)

    # URI inputs: one data entry each.
    args["force_psi"] = None
    args["hide_legend"] = False
    uri_plotters = {
        "ech": _show_waves_plots,
        "icrh": _show_distribution_plots,
        "nbi": _show_distribution_plots,
        "fus": _show_distribution_plots,
    }
    for source in uri_sources:
        connargs = argparse.Namespace()
        connargs.uri = args[source]
        uri_plotters[source](connargs, args, holder[f"{source}_uri_hold"], args["dd_update"], rc=args["rc"])

    # Legacy inputs: ECH, nuclear reactions, NBI, then ICRH.
    legacy_plotters = (
        ("ech", show_waves_plots),
        ("fus", show_distribution_plots),
        ("nbi", show_distribution_plots),
        ("icrh", show_distribution_plots),
    )
    for source, plotter in legacy_plotters:
        source_args = hcd_sources.get(source)
        if not source_args or not source_args["run"]:
            continue
        if source == "ech":
            source_args["force_psi"] = None
            source_args["hide_legend"] = False
        source_args["save"] = args["save"]
        plotter(source_args, hold=holder[f"{source}_hold"], dd_update=args["dd_update"], rc=args["rc"])
