#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import logging
import os

try:
    import imaspy as imas

except ImportError:
    import imas
from rich_argparse import RichHelpFormatter

from idstools.compute.common import get_nearest_time
from idstools.database import DBMaster
from idstools.utils.clihelper import (
    dbentry_parser,
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
)
from idstools.utils.idshelper import get_available_ids_and_occurrences
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.core_profiles import CoreProfilesView
from idstools.view.equilibrium import EquilibriumView
from idstools.view.summary import SummaryView

# ----------------------------------------------------------------------
def _load_ids(connection, available_ids, ids_name, dd_update, on_error):
    """Read the first readable occurrence of ``ids_name`` from the data-entry.

    Occurrences are tried in the order they appear in ``available_ids``; when
    reading (or converting) one fails, the exception is handed to ``on_error``
    and the next matching occurrence is tried.

    Args:
        connection: open data-entry returned by ``DBMaster.get_connection``.
        available_ids: iterable of ``(ids_name, occurrence)`` pairs, as returned
            by ``get_available_ids_and_occurrences``.
        ids_name: name of the IDS to read, e.g. ``"core_profiles"``.
        dd_update: when true the IDS is read eagerly and converted with
            ``imas.convert_ids`` to ``connection.factory.version``; otherwise it
            is read lazily without conversion.
        on_error: callable receiving the exception raised for a failed occurrence.

    Returns:
        The IDS object, or ``None`` when no matching occurrence could be read.
    """
    for name, occurrence in available_ids:
        if name != ids_name:
            continue
        try:
            if dd_update:
                ids = connection.get(ids_name, occurrence=occurrence, autoconvert=False)
                # Return only after a successful conversion so a failed conversion
                # never leaves a half-processed IDS behind.
                return imas.convert_ids(ids, connection.factory.version)
            return connection.get(ids_name, occurrence=occurrence, lazy=True, autoconvert=False)
        except Exception as e:  # best effort: report and try the next occurrence
            on_error(e)
    return None


def _time_is_empty(ids):
    """Return True when ``ids.time`` is missing or holds no samples."""
    return ids.time is None or len(ids.time) == 0


def _provenance_text(ids_summary):
    """Return the provenance lines appended to the figure title by ``--info``.

    The provenance is read from the summary IDS properties, so an empty string
    is returned when ``ids_summary`` is None.
    """
    if ids_summary is None:
        return ""
    properties = ids_summary.ids_properties
    return (
        f"\nprovider={properties.provider}, "
        f"creation_date={properties.creation_date}\n"
        f"data_dictionary={properties.version_put.data_dictionary}, "
        f"access_layer={properties.version_put.access_layer}"
    )


def _parse_arguments():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="---- Display the plasma kinetic profiles and equilibrium"
        " from the core_profiles and equilibrium IDSs [previously known as scenplot]",
        parents=[dbentry_parser, rcparam_parser],
        formatter_class=RichHelpFormatter,
    )

    group = parser.add_mutually_exclusive_group()
    group.add_argument("-t", "--time", help="Time for profiles", required=False, type=float, default=-99.0)
    group.add_argument(
        "-n",
        "--no-profiles",
        action="store_true",
        help="Do not plot profiles or equilibrium",
    )
    parser.add_argument(
        "-i",
        "--info",
        action="store_true",
        help="Add title with additional provenance information",
    )
    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    return parser.parse_args()


def _load_profiles(connection, available_ids, args, logger):
    """Read core_profiles and equilibrium and pick the time slice of each.

    The displayed time is taken from core_profiles (nearest to ``args.time``);
    the equilibrium slice is then chosen nearest to that same time so both
    panels show the same instant, even when the two IDSs have different time
    bases. When core_profiles is unusable, the equilibrium alone defines the
    displayed time. An IDS whose time array is empty is discarded.

    Returns:
        ``(ids_core_profiles, ids_equilibrium, core_profiles_slice,
        equilibrium_slice, time_value)``; unusable IDSs are ``None`` and the
        slices/time default to ``0``/``0.0``.
    """
    time_value = 0.0
    core_profiles_slice = 0
    equilibrium_slice = 0
    profile_time_found = False

    ids_core_profiles = _load_ids(
        connection,
        available_ids,
        "core_profiles",
        args.dd_update,
        lambda e: logger.error(f"core_profiles ids is not present {e}"),
    )
    if ids_core_profiles:
        if _time_is_empty(ids_core_profiles):
            ids_core_profiles = None
            logger.critical("core_profiles IDS time is empty")
        else:
            core_profiles_slice, time_value = get_nearest_time(ids_core_profiles.time, args.time)
            profile_time_found = True

    ids_equilibrium = _load_ids(
        connection,
        available_ids,
        "equilibrium",
        args.dd_update,
        lambda e: logger.error(f"equilibrium ids is not present, detailed error: {e}"),
    )
    if ids_equilibrium:
        if _time_is_empty(ids_equilibrium):
            ids_equilibrium = None
            logger.critical("equilibrium IDS time is empty")
        elif profile_time_found:
            # Align the equilibrium with the time at which the profiles are shown.
            equilibrium_slice, _ = get_nearest_time(ids_equilibrium.time, time_value)
        else:
            equilibrium_slice, time_value = get_nearest_time(ids_equilibrium.time, args.time)

    return ids_core_profiles, ids_equilibrium, core_profiles_slice, equilibrium_slice, time_value


def _create_axes(show_summary, show_profiles):
    """Create the figure layout.

    Returns:
        ``(canvas, summary_axes, profile_axes)`` where ``summary_axes`` is the
        tuple of the four stacked time-trace axes and ``profile_axes`` is
        ``(equilibrium_ax, q_profile_ax, current_density_ax)``; either tuple is
        ``None`` when the corresponding panels are not drawn.
    """
    summary_axes = None
    profile_axes = None
    if show_summary:
        canvas = PlotCanvas(4, 3 if show_profiles else 1)
        top = canvas.add_axes(title="", xlabel="", row=0, col=0, colspan=2)
        lower = [
            canvas.add_axes(title="", xlabel="", row=row, col=0, sharex=top, colspan=2) for row in (1, 2, 3)
        ]
        summary_axes = (top, *lower)
        if show_profiles:
            profile_axes = (
                canvas.add_axes(title="", xlabel="", row=0, col=2, rowspan=2),
                canvas.add_axes(title="", xlabel="", row=2, col=2),
                canvas.add_axes(title="", xlabel="", row=3, col=2),
            )
    else:
        canvas = PlotCanvas(2, 2)
        if show_profiles:
            profile_axes = (
                canvas.add_axes(title="", xlabel="", row=0, col=0, rowspan=2),
                canvas.add_axes(title="", xlabel="", row=0, col=1),
                canvas.add_axes(title="", xlabel="", row=1, col=1),
            )
    return canvas, summary_axes, profile_axes


def _plot_summary(ids_summary, axes, time_value, show_profiles, logger):
    """Draw the four summary time traces (H&CD, Ip/B0, energy, Vloop).

    Every panel gets the H-mode shading; when profiles are shown, a vertical
    line marks ``time_value`` and the top panel is titled with that time.
    """
    if len(ids_summary.time) < 1:
        logger.critical("The summary IDS is absent from the input data-entry")

    view = SummaryView(ids_summary)
    waveform_plotters = (
        view.view_hcd_waveforms,
        view.view_ip_b0_waveforms,
        view.view_energy_content_waveforms,
        view.view_vloop_waveforms,
    )
    top_ax, bottom_ax = axes[0], axes[-1]
    for ax, plot_waveforms in zip(axes, waveform_plotters):
        plot_waveforms(ax)
        view.view_hmode(ax)
        if ax is not bottom_ax:
            # The panels share the time axis: only the bottom one keeps x tick labels.
            ax.tick_params(axis="x", which="both", bottom=False, top=False, right=False, labelbottom=False)
        if show_profiles:
            if ax is top_ax:
                ax.set_title(f"Profiles displayed for t = {time_value:.1f} s")
            view.view_time_line(ax, time_value)


def _plot_profiles(ids_core_profiles, ids_equilibrium, axes, core_profiles_slice, equilibrium_slice):
    """Draw q/shear and current-density profiles plus the equilibrium.

    Each IDS is indexed with its own time slice. Missing IDSs are skipped.
    """
    equilibrium_ax, q_profile_ax, current_density_ax = axes
    if ids_core_profiles:
        profiles_view = CoreProfilesView(ids_core_profiles)
        profiles_view.view_q_profile_and_magnetic_shear_profile(q_profile_ax, core_profiles_slice)
        profiles_view.view_current_density_profiles(current_density_ax, core_profiles_slice)
    if ids_equilibrium:
        EquilibriumView(ids_equilibrium).plotequilibrium(equilibrium_ax, equilibrium_slice)


def main():
    """Entry point: read the IDSs, build the scenario figure, then show or save it."""
    args = _parse_arguments()
    logger = setup_logger("module", stdout_level=logging.INFO)
    show_profiles = not args.no_profiles

    # Open the database; bail out before querying it if the connection failed.
    connection = DBMaster.get_connection(args)
    if connection is None:
        logger.critical("----> Aborted.")
        raise SystemExit(1)
    available_ids = get_available_ids_and_occurrences(connection)

    ids_core_profiles = None
    ids_equilibrium = None
    core_profiles_slice = 0
    equilibrium_slice = 0
    time_value = 0.0
    if show_profiles:
        (
            ids_core_profiles,
            ids_equilibrium,
            core_profiles_slice,
            equilibrium_slice,
            time_value,
        ) = _load_profiles(connection, available_ids, args, logger)

    ids_summary = _load_ids(
        connection,
        available_ids,
        "summary",
        args.dd_update,
        lambda e: logger.critical(f"The summary IDS is absent from the input data-entry, detailed error: {e}"),
    )

    has_summary = bool(ids_summary)
    if not has_summary:
        logger.critical("summary IDS is absent from the input data-entry")
    canvas, summary_axes, profile_axes = _create_axes(has_summary, show_profiles)
    canvas.update_style(args.rc)

    if has_summary:
        _plot_summary(ids_summary, summary_axes, time_value, show_profiles, logger)
    if show_profiles:
        _plot_profiles(ids_core_profiles, ids_equilibrium, profile_axes, core_profiles_slice, equilibrium_slice)

    if show_profiles:
        title = get_title(args, "Scenario", time_value)
    else:
        title = get_title(args, "Scenario")
    if args.info:
        if has_summary:
            title += _provenance_text(ids_summary)
        else:
            # Provenance comes from the summary IDS; do not crash when it is missing.
            logger.warning("Provenance information is not available: summary IDS is absent")

    if show_profiles:
        canvas.set_text(text=f"{get_database_path(args, time_value=time_value)}")
    else:
        canvas.set_text(text=f"{get_database_path(args)}")
    canvas.set_sup_title(title)
    canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))
    canvas.fig.subplots_adjust(top=0.914, bottom=0.099, left=0.042, right=0.9, hspace=0.113, wspace=0.43)

    if args.save:
        if show_profiles:
            fname = get_file_name(args, os.path.basename(__file__) + "_summary", time_value)
        else:
            fname = get_file_name(args, os.path.basename(__file__) + "_Scenario_shot")
        if args.directory:
            os.makedirs(args.directory, exist_ok=True)
            fname = os.path.join(args.directory, fname)
        canvas.save(fname)
    else:
        canvas.show()


if __name__ == "__main__":
    main()
