#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os
import sys

try:
    import imaspy as imas
except ImportError:
    import imas
from rich_argparse import RichHelpFormatter

from idstools.compute.common import get_nearest_time
from idstools.database import DBMaster
from idstools.utils.clihelper import (
    dbentry_parser,
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
)
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.distributions import DistributionsView

logger = setup_logger("module", stdout_level=logging.INFO)


def _load_distributions(conn, args):
    """Read the ``distributions`` IDS from ``conn``.

    When ``args.dd_update`` is set, the IDS is loaded eagerly and converted to the
    data-dictionary version of ``conn.factory``. Otherwise it is loaded lazily
    with ``autoconvert=False``.

    Returns the IDS, or ``None`` when it cannot be read. If only the conversion
    step fails, the unconverted IDS is returned. This matches the
    non-``dd_update`` path, which also reads with ``autoconvert=False``.
    """
    try:
        if not args.dd_update:
            return conn.get("distributions", lazy=True, autoconvert=False)
        ids_distributions = conn.get("distributions", autoconvert=False)
    except Exception as e:
        logger.error(f"distributions ids is not present, detailed error: {e}")
        return None

    try:
        return imas.convert_ids(ids_distributions, conn.factory.version)
    except Exception as e:
        # Reading succeeded; only the DD conversion failed. Report that
        # accurately and fall back to the data as stored.
        logger.error(f"distributions ids could not be converted, plotting unconverted data, detailed error: {e}")
        return ids_distributions


def _save_or_show(canvas, args, time_value):
    """Save the figure when ``args.save`` is set (optionally under ``args.directory``), else show it."""
    if not args.save:
        canvas.show()
        return
    fname = get_file_name(args, os.path.basename(__file__) + "_Distributions_profile", time_value)
    if args.directory:
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(args.directory, exist_ok=True)
        fname = os.path.join(args.directory, fname)
    canvas.save(fname)


def _plot_distributions(ids_distributions, args):
    """Draw the distributions panels on a 3x2 canvas, then save or show the figure.

    The left column (rows 0-2) always shows the individual absorbed power density,
    the total absorbed power density and the current-drive profile at the time slice
    nearest to ``args.time``. The right column (rows 0-1) shows the NBI/fusion
    power & CD waveforms and the CD waveform. It is drawn only when the IDS holds
    more than one time slice.
    """
    time_array = ids_distributions.time
    ntime = len(time_array)
    if ntime == 0:
        # Nothing to index: get_nearest_time needs at least one time value.
        logger.error("distributions ids has no time slices --> nothing to plot")
        return
    time_slice, time_value = get_nearest_time(time_array, args.time)

    canvas = PlotCanvas(3, 2)
    canvas.update_style(args.rc)

    has_waveforms = ntime != 1
    if not has_waveforms:
        logger.info("Only one time slice --> Power and CD waveforms not displayed")

    ax1 = canvas.add_axes(title="", xlabel="", row=0, col=0)
    ax2 = canvas.add_axes(title="", xlabel="", row=1, col=0)
    ax3 = canvas.add_axes(title="", xlabel="", row=2, col=0)
    distributions_view = DistributionsView(ids_distributions)
    distributions_view.plot_absorbed_power_density_individual(ax1, time_slice)
    distributions_view.plot_absorbed_power_density(ax2, time_slice)
    distributions_view.plot_cd_profile(ax3, time_slice)
    if has_waveforms:
        ax4 = canvas.add_axes(title="", xlabel="", row=0, col=1)
        ax5 = canvas.add_axes(title="", xlabel="", row=1, col=1)
        distributions_view.plot_nbi_fus_power_and_cd_waveforms(ax4, time_slice)
        distributions_view.plot_cd_waveform(ax5, time_slice)

    canvas.set_text(text=f"{get_database_path(args, time_value=time_value)}")

    # Hand-tuned margins so that titles, labels and the footer text do not overlap.
    canvas.fig.subplots_adjust(
        top=0.92,
        bottom=0.122,
        left=0.044,
        right=0.886,
        hspace=0.438,
        wspace=0.328,
    )
    canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))
    canvas.fig.suptitle(get_title(args, "Distributions profile", time_value))

    _save_or_show(canvas, args, time_value)


def show_plots(args):
    """Open the data entry described by ``args`` and plot its ``distributions`` IDS.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. This module reads ``dd_update``, ``time``,
        ``rc``, ``save`` and ``directory``. The remaining fields are consumed by
        ``DBMaster.get_connection`` and the ``clihelper`` title/path helpers.

    Exits the process with status 1 when no connection can be opened. The
    connection is always closed afterwards, even if loading or plotting raises.
    """
    conn = DBMaster.get_connection(args)
    if conn is None:
        logger.critical("----> Aborted.")
        sys.exit(1)

    try:
        ids_distributions = _load_distributions(conn, args)
        if ids_distributions:
            _plot_distributions(ids_distributions, args)
    finally:
        conn.close()


if __name__ == "__main__":
    # Command-line entry point: combine the shared data-entry and rcParams
    # options with the script-specific time/save options, then plot.
    parser = argparse.ArgumentParser(
        description="---- Display distributions results",
        parents=[dbentry_parser, rcparam_parser],
        formatter_class=RichHelpFormatter,
    )
    # -99.0 is the sentinel passed to get_nearest_time when no time is given.
    parser.add_argument("-t", "--time", help="Time", required=False, type=float, default=-99.0)

    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    args = parser.parse_args()

    show_plots(args)
