#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import logging
import os

from rich_argparse import RichHelpFormatter

from idstools.compute.common import get_nearest_time
from idstools.database import read_scenario, read_scenario_with_args
from idstools.input_processing import (
    beam_wall_crossing,
    check_rays_into_divertor,
    read_launching_parameters,
    read_torbeam_output,
    read_wall,
)
from idstools.utils.clihelper import (
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
    dbentry_parser,
)
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.core_profiles import CoreProfilesView
from idstools.view.domain.ecstray import EcStrayView
from idstools.view.equilibrium import EquilibriumView
from idstools.view.polygon import PolygonView
from idstools.view.waves import WavesView

logger = setup_logger("module", stdout_level=logging.INFO)

# Script entry point: reads the wall, EC launcher and TORBEAM inputs, loads the
# equilibrium / core_profiles / waves IDSs, and draws an EC stray radiation
# overview (top view, poloidal view, wall footprints, waveforms, density, beam index).
if __name__ == "__main__":
    # Database-entry selection and rcParams options come from the shared parent parsers.
    parser = argparse.ArgumentParser(
        description="---- Shows electron cyclotron stray radiation information by showing different plots",
        formatter_class=RichHelpFormatter,
        parents=[dbentry_parser, rcparam_parser],
    )
    # -99.0 is a sentinel forwarded to get_nearest_time; the help text says it selects
    # the middle time (assumed to be handled inside get_nearest_time — not visible here).
    parser.add_argument("-t", "--time", help="Time (default=middle)", type=float, default=-99.0)
    parser.add_argument(
        "--logscale",
        help="Shows y axis with logarithmic scale wherever appropriate",
        action="store_true",
    )
    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    args = parser.parse_args()

    # Resolve resources: prefer ../resources/{input,results} relative to this script,
    # falling back to sibling input/ and results/ directories when those do not exist.
    current_file_path = os.path.dirname(os.path.abspath(__file__))

    scenario_file = os.path.join(current_file_path, "../resources/input/scenario.yaml")
    wallfile = os.path.join(current_file_path, "../resources/input/wall2d.txt")
    filelaunchers = os.path.join(current_file_path, "../resources/input/ec_waveforms.yaml")
    path_result = os.path.join(current_file_path, "../resources/results/")

    if not os.path.exists(scenario_file):
        scenario_file = os.path.join(current_file_path, "input/scenario.yaml")
    if not os.path.exists(wallfile):
        wallfile = os.path.join(current_file_path, "input/wall2d.txt")
    if not os.path.exists(filelaunchers):
        filelaunchers = os.path.join(current_file_path, "input/ec_waveforms.yaml")
    if not os.path.exists(path_result):
        path_result = os.path.join(current_file_path, "results/")

    wall2d = read_wall(wallfile)

    # Read launching parameters from the EC waveform file.
    launching_parameters = read_launching_parameters(filelaunchers)
    # Read beam extra variables from the TORBEAM ASCII output files. The TORBEAM time
    # array is discarded: all time indices below are taken from the IDS time bases.
    beam_output, _ = read_torbeam_output(launching_parameters, path_result)
    # Check whether rays go into the divertor.
    check_rays_into_divertor(wall2d, beam_output)
    # Compute where the beams cross the wall.
    beam_wall = beam_wall_crossing(wall2d, launching_parameters, beam_output)

    if args.uri is None:
        # No URI on the command line: take the database entry from the scenario file
        # and rebuild an MDSplus URI so the title/footer can describe the source.
        in_ids_list, out_ids_dict, inputargs = read_scenario(
            scenario_file_path=scenario_file,
            in_ids_list=["equilibrium", "core_profiles"],
            out_ids_list=["waves"],
        )
        args.uri = (
            f"imas:mdsplus?user={inputargs.user};pulse={inputargs.pulse};"
            f"run={inputargs.run};database={inputargs.database};version={inputargs.version}"
        )
    else:
        data = read_scenario_with_args(
            imasargs=args,
            in_ids_list=["equilibrium", "core_profiles"],
            out_ids_list=["waves"],
        )
        if data is None:
            # Previously execution continued and crashed with a NameError on
            # in_ids_list; fail explicitly instead.
            logger.error("Could not read equilibrium/core_profiles/waves IDSs from URI %s", args.uri)
            raise SystemExit(1)
        in_ids_list, out_ids_dict = data

    equilibrium_ids = in_ids_list["equilibrium"]
    core_profiles_ids = in_ids_list["core_profiles"]
    waves_ids = out_ids_dict["waves"]

    time_array_equilibrium = equilibrium_ids.time
    time_array_core_profiles = core_profiles_ids.time
    time_array_waves = waves_ids.time

    time_slice = args.time

    # Indices into the time arrays of the equilibrium, core_profiles and waves IDSs.
    time_index_equilibrium, time_value_equilibrium = get_nearest_time(time_array_equilibrium, time_slice)
    time_index_core_profiles, _ = get_nearest_time(time_array_core_profiles, time_slice)
    time_index_waves, _ = get_nearest_time(time_array_waves, time_slice)

    equilibrium_view = EquilibriumView(equilibrium_ids)
    core_profiles_view = CoreProfilesView(core_profiles_ids)
    waves_view = WavesView(waves_ids)
    ecstray_view = EcStrayView(equilibrium_ids, core_profiles_ids, waves_ids)

    # 3x3 grid: top/poloidal views span two rows; the beam-index panel spans two columns.
    canvas = PlotCanvas(3, 3)
    canvas.update_style(args.rc)
    ax_top_view = canvas.add_axes(title="Top View (X,Y)", xlabel="X [m]", ylabel="Y [m]", row=0, col=0, rowspan=2)
    ax_pol_view = canvas.add_axes(title="Poloidal view (R,Z)", xlabel="R [m]", ylabel="Z [m]", row=0, col=1, rowspan=2)

    ax_polygon = canvas.add_axes(
        title="Beam footprints on the wall",
        xlabel=r"$\phi \times R_{max}$ [Rad.m]",
        ylabel="Length along polygon [m]",
        row=2,
        col=0,
    )
    ax_waveform = canvas.add_axes(title="Current and Electrons density waveforms", xlabel="Time [s]", row=0, col=2)
    ax_density = canvas.add_axes(
        xlabel=r"Normalised $\rho_{tor}$ [-]",
        ylabel="Density [m-3]",
        row=1,
        col=2,
    )
    ax_beam_index = canvas.add_axes(title="Beam index", xlabel="Beam index", row=2, col=1, colspan=2)

    # Waveforms versus time: plasma current and central electron density.
    equilibrium_view.plot_ip(ax_waveform)
    core_profiles_view.plot_electron_density_ne0(ax_waveform)
    waves_view.plot_beam_index(ax_beam_index)

    # Poloidal view: equilibrium, ray traces, resonance and cut-off layers.
    equilibrium_view.plot_poloidal_equilibrium(ax_pol_view, time_index_equilibrium)
    coherent_wave_index = 0
    waves_view.plot_poloidal_traces_update(ax_pol_view, time_index_waves, verbose=True)
    ecstray_view.plot_resonance_layer(ax_pol_view, coherent_wave_index, time_index_waves, verbose=True)
    ecstray_view.plot_cut_off_layer(ax_pol_view, coherent_wave_index, time_index_waves)

    # Top view: equilibrium and ray traces.
    equilibrium_view.plot_topplotequilibrium(ax_top_view, time_index_equilibrium)
    waves_view.plot_topview_traces_update(ax_top_view, coherent_wave_index, time_index_waves)
    core_profiles_view.plot_density_profile(ax_density, time_index_core_profiles, logscale=args.logscale)

    pview = PolygonView()
    # NOTE(review): time_index_waves is passed twice — presumably a (start, end) pair
    # selecting a single slice; confirm against PolygonView.plot_polygon's signature.
    pview.plot_polygon(ax_polygon, wall2d, beam_wall, coherent_wave_index, time_index_waves, time_index_waves)

    canvas.set_text(text=f"{get_database_path(args, time_value=time_value_equilibrium)}")
    canvas.fig.subplots_adjust(
        top=0.88,
        bottom=0.11,
        left=0.035,
        right=0.902,
        hspace=0.458,
        wspace=0.234,
    )
    canvas.fig.suptitle(get_title(args, "EC Stray Radiation", time_value_equilibrium))
    canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))
    if args.save:
        # NOTE(review): the "_Equilibrium" suffix looks copied from another script; kept
        # unchanged so existing output file names stay stable.
        fname = get_file_name(args, f"{os.path.basename(__file__)}_Equilibrium", time_value_equilibrium)
        if args.directory:
            # exist_ok avoids the race between an existence check and directory creation.
            os.makedirs(args.directory, exist_ok=True)
            fname = os.path.join(args.directory, fname)
        canvas.save(fname)
    else:
        canvas.show()
