#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# python scripts/plotequilibrium -p 134174 -r 117
# -md "imas:mdsplus?user=public;shot=116000;run=2;database=ITER_MD;version=3#wall"
# "imas:mdsplus?user=public;shot=111001;run=102;database=ITER_MD;version=3#pf_active"

import argparse
import copy
import logging
import os

try:
    import imaspy as imas
except ImportError:
    import imas
from rich_argparse import RichHelpFormatter

from idstools.compute.common import get_nearest_time
from idstools.compute.equilibrium import EquilibriumCompute
from idstools.database import DBMaster
from idstools.machinedescription import get_md_data
from idstools.utils.clihelper import (
    dbentry_parser,
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
)
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.domain.mdplot import plot_machine_description
from idstools.view.equilibrium import EquilibriumView


class MdAction(argparse.Action):
    """Argparse action storing ``True`` for a value-less option, else the values.

    When the option is given without any value the destination attribute is
    set to ``True``; otherwise the received values are stored unchanged.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # An empty value list means the flag was passed on its own.
        stored = True if len(values) == 0 else values
        setattr(namespace, self.dest, stored)


def _build_parser():
    """Create the command-line parser for the equilibrium plotting script."""
    parser = argparse.ArgumentParser(
        description="""---- Display the plasma equilibrium from the equilibrium IDS.
        It also shows pf coils and wall position overlay if exists [Previously known as equiplot]""",
        parents=[dbentry_parser, rcparam_parser],
        formatter_class=RichHelpFormatter,
    )
    parser.add_argument("-t", "--time", help="Time (default=middle)", type=float, default=-99.0)
    parser.add_argument(
        "--rho",
        help="Show rho overlay on the plot",
        action="store_true",
    )
    parser.add_argument(
        "-p",
        "--plots",
        help="Plots available quantities along with equilibrium",
        action="store_true",
    )
    parser.add_argument(
        "-md",
        "--md",
        nargs="*",
        action=MdAction,
        default=[],
        help="""Provide machine descriptions that you need to plot\n
    with ids names for example wall pf_active \n
    or with uris for example\n
    "imas:mdsplus?user=public;shot=116000;run=4;database=ITER_MD;version=3#wall",
    "imas:mdsplus?user=public;shot=116000;run=4;database=ITER_MD;version=3",
    "imas:hdf5?path=./testdb",
    "testpulse.nc"
    """,
    )
    parser.add_argument(
        "--show-labels",
        help="Show labels",
        action="store_true",
    )
    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    return parser


def _parse_occurrence(uri):
    """Return the occurrence number encoded in the fragment of *uri*.

    The fragment (text after ``#``) is read as ``<ids>[:<occurrence>][/<path>]``.
    ``0`` is returned when *uri* is empty, does not contain exactly one ``#``,
    or has no ``:<occurrence>`` part before the first ``/``.

    Raises:
        ValueError: if the occurrence part is present but not an integer.
    """
    if not uri:
        return 0
    uri_parts = uri.split("#")
    if len(uri_parts) != 2:
        return 0
    # Isolate "<ids>[:<occurrence>]" first so that a ':' appearing later in
    # the path part (e.g. "time_slice(:)") cannot hide the occurrence.
    ids_part = uri_parts[1].split("/", 1)[0]
    _, separator, occurrence_text = ids_part.partition(":")
    if not separator:
        return 0
    try:
        return int(occurrence_text)
    except ValueError:
        raise ValueError(f"Invalid occurrence {occurrence_text!r} in URI fragment of {uri!r}") from None


def _is_uri_or_path(entry):
    """Return True when a ``-md`` entry looks like a URI or file path rather than an IDS name."""
    base = entry.split("#")[0]
    return (
        entry.startswith("imas:")
        or base.startswith((".", "~"))
        or base.lower().endswith(".nc")
        or "/" in base
        or "\\" in base
        or os.path.exists(base)
    )


def _split_md_entries(md_entries):
    """Split ``-md`` entries into (list of URIs/paths, comma-terminated IDS-name string).

    Every entry is stripped of surrounding whitespace. IDS names are
    concatenated with a trailing comma after each name (e.g. ``"wall,pf_active,"``).
    """
    uris = []
    ids_names = []
    for entry in md_entries:
        entry = entry.strip()
        if _is_uri_or_path(entry):
            uris.append(entry)
        else:
            ids_names.append(entry)
    return uris, "".join(f"{name}," for name in ids_names)


def _plot_grid_columns(total_plots):
    """Return the canvas column count: one column for the 2D map plus two side plots per column."""
    return (total_plots + 1) // 2 + 1


def main(argv=None):
    """Plot the 2D equilibrium (optionally with machine description and extra quantities).

    Args:
        argv: command-line arguments to parse; ``None`` lets argparse read ``sys.argv``.

    Raises:
        SystemExit: with status 1 when no database connection is obtained, or
            through ``parser.error`` when the URI fragment has an invalid occurrence.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    try:
        occurrence = _parse_occurrence(args.uri)
    except ValueError as error:
        # Report a malformed occurrence as a usage error instead of a traceback.
        parser.error(str(error))

    logger = setup_logger("module", stdout_level=logging.INFO)

    connection = DBMaster.get_connection(args)
    if connection is None:
        logger.critical("----> Aborted.")
        # The bare exit() builtin is injected by the site module and is not
        # guaranteed to exist; raising SystemExit always works.
        raise SystemExit(1)

    # A bare "-md" is stored as True by MdAction: use the equilibrium URI itself
    # as the machine description source.
    if args.md is True:
        args.md = [args.uri]

    if args.dd_update:
        ids_obj_equilibrium = connection.get("equilibrium", occurrence=occurrence, autoconvert=False)
        ids_obj_equilibrium = imas.convert_ids(ids_obj_equilibrium, connection.factory.version)
    else:
        ids_obj_equilibrium = connection.get("equilibrium", occurrence=occurrence, lazy=True, autoconvert=False)

    # NOTE(review): only None is rejected here; an empty time array would pass
    # through to get_nearest_time — confirm that is handled there.
    if ids_obj_equilibrium.time is None:
        logger.warning("Can not produce plot, equilibrium/time is None")
        return

    time_slice, time_value = get_nearest_time(ids_obj_equilibrium.time, args.time)
    view_object = EquilibriumView(ids_obj_equilibrium)

    title = f"2D Equilibrium at time {time_value:.3f}"
    database_text = ""
    axes_list1 = []  # axes for the 1D profile quantities
    axes_list2 = []  # axes for the global quantities
    if args.plots:
        compute_obj = EquilibriumCompute(ids_obj_equilibrium)
        profiles_1d_quantities = compute_obj.get_profiles_1d_quantities(time_slice, ["pressure", "q", "beta_pol"])
        p1dcounter = sum(1 for value in profiles_1d_quantities.values() if value.has_value)

        global_quantities = compute_obj.get_global_quantities(
            time_slice, ["q_min.value", "q_95", "li_3", "beta_tor", "energy_mhd"]
        )
        gcounter = sum(1 for value in global_quantities.values() if value["has_value"])

        total_plots = p1dcounter + gcounter
        col_size = _plot_grid_columns(total_plots)

        canvas = PlotCanvas(2, col_size)
        # The 2D equilibrium map occupies the whole first column.
        ax1 = canvas.add_axes(title="", xlabel="", row=0, col=0, rowspan=2)
        # Fill the remaining cells column by column: profile axes first, then
        # global-quantity axes; surplus cells stay empty.
        plotting_counter = 0
        for col in range(1, col_size):
            for row in (0, 1):
                if plotting_counter < p1dcounter:
                    axes_list1.append(canvas.add_axes(title="", xlabel="", row=row, col=col))
                elif plotting_counter < total_plots:
                    axes_list2.append(canvas.add_axes(title="", xlabel="", row=row, col=col))
                plotting_counter += 1
    else:
        canvas = PlotCanvas(1, 1)
        ax1 = canvas.add_axes(title="", xlabel="", row=0, col=0)

    canvas.update_style(args.rc)
    if args.md:
        mduris, idses = _split_md_entries(args.md)
        if idses:
            # Plain IDS names are looked up in the equilibrium data entry as well.
            mduris.append(args.uri)
            ids_data = get_md_data(mduris, args.dd_update, idses=idses)
        else:
            ids_data = get_md_data(mduris, args.dd_update)
        plot_machine_description(ax1, ids_data)

    c_psi, c_rho = view_object.view_magnetic_poloidal_flux(ax1, time_slice, plot_rho=args.rho)
    if c_psi:
        cbar_psi = canvas.fig.colorbar(c_psi, ax=ax1, orientation="horizontal", pad=0.08, fraction=0.03)
        cbar_psi.set_label(r"$\psi$ [Wb]")
    if c_rho:
        cbar_rho = canvas.fig.colorbar(c_rho, ax=ax1, orientation="horizontal", pad=0.08, fraction=0.03)
        # NOTE(review): rho is labelled in [Wb], same as psi — confirm the unit.
        cbar_rho.set_label(r"$\rho$ [Wb]")
    ax1.set_title(title)

    # Write the database path vertically just outside the right edge of the 2D plot.
    xmin, xmax = ax1.get_xlim()
    ymin, ymax = ax1.get_ylim()
    ax1.text(
        xmax + 0.01 * abs(xmax),
        ymin + 0.5 * abs(ymax - ymin),
        f"{get_database_path(args, time_value=time_value)}\n{database_text}",
        horizontalalignment="left",
        verticalalignment="center",
        rotation="vertical",
        fontsize=7,
    )
    if args.plots:
        view_object.plot_profiles_1d_quantities(axes_list1, time_slice)
        view_object.plot_global_quantities(axes_list2, time_value)

    canvas.fig.suptitle(get_title(args, "Equilibrium", time_value))
    canvas.fig.set_size_inches(14, 8)
    canvas.fig.subplots_adjust(top=0.933, bottom=0.05, left=0.024, right=0.988, hspace=0.221, wspace=0.25)
    canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))
    if args.save:
        fname = get_file_name(args, f"{os.path.basename(__file__)}_Equilibrium", time_value)
        if args.directory:
            # exist_ok avoids the check-then-create race of a separate exists() test.
            os.makedirs(args.directory, exist_ok=True)
            fname = os.path.join(args.directory, fname)
        canvas.save(fname)
    else:
        canvas.show()


if __name__ == "__main__":
    main()
