#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# David.Coster@ipp.mpg.de
# migrated from check_transport by David.Coster@ipp.mpg.de

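# Flux expression (presumably the "simple calculation" the stored transport fluxes are checked against):
#   V' [-D Y' <|grad(rho_tor_norm)|^2> + v Y <|grad(rho_tor_norm)|>]
# with the flux-surface averages:
#   gm3(:)  Flux surface averaged |grad_rho_tor|^2 {dynamic} [-]
#   gm7(:)  Flux surface averaged |grad_rho_tor| {dynamic} [-]
# i.e.
#   V' [-D Y' gm3 + v Y gm7]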

import argparse
import logging
import os
import sys

try:
    import imaspy as imas
except ImportError:
    import imas
from rich_argparse import RichHelpFormatter

from idstools.database import DBMaster
from idstools.utils.clihelper import (
    dbentry_parser,
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
)
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas
from idstools.view.core_transport import CoreTransportView

# Module-wide logger, created with the stdout level set to INFO.
logger = setup_logger("module", stdout_level=logging.INFO)

# Lookup table from the slicing-method option name to the matching
# ``imas.ids_defs.<NAME>_INTERP`` constant (CLOSEST_INTERP, PREVIOUS_INTERP, LINEAR_INTERP).
slicing_methods = {
    method: getattr(imas.ids_defs, f"{method}_INTERP") for method in ("CLOSEST", "PREVIOUS", "LINEAR")
}
if __name__ == "__main__":
    # Script entry point: parse the command line, load one time slice of the
    # core_transport, core_profiles and equilibrium IDSs, plot the electron and
    # ion particle/energy fluxes, then either save or show the figure.

    # Management of input arguments
    parser = argparse.ArgumentParser(
        description="Check match between transport fluxes and a simple calculation [previously known as check_transport]",
        formatter_class=RichHelpFormatter,
        parents=[dbentry_parser, rcparam_parser],
    )

    parser.add_argument(
        "-m",
        "--slicingmethod",
        type=str,
        default="CLOSEST",
        # Derived from the lookup table so the accepted names cannot drift from it.
        choices=list(slicing_methods),
        help="Slicing method \t(default=%(default)s)",
    )

    parser.add_argument("-o", "--occurrence", type=int, default=0, help="occurrence")
    parser.add_argument("-t", "--time", type=float, help="Time", default=-99)
    parser.add_argument(
        "--logscale",
        help="Shows y axis with logarithmic scale wherever appropriate",
        action="store_true",
    )
    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    args = parser.parse_args()

    connection = DBMaster.get_connection(args)
    if connection is None:
        # sys.exit instead of the site-provided exit(), which is not guaranteed to exist.
        sys.exit(1)

    interpolation_method = slicing_methods[args.slicingmethod]

    def _load_slice(ids_name):
        """Fetch one time slice of the IDS ``ids_name`` at ``args.time``.

        The slice is converted to the connection's data-dictionary version when
        ``args.dd_update`` is set. The process is terminated with exit status 1
        if the resulting time vector is empty.
        """
        ids_slice = connection.get_slice(
            ids_name,
            args.time,
            autoconvert=False,
            interpolation_method=interpolation_method,
            occurrence=args.occurrence,
        )
        if args.dd_update:
            ids_slice = imas.convert_ids(ids_slice, connection.factory.version)
        # Explicit length check: the time node is array-like, so plain truthiness is avoided.
        if len(ids_slice.time) == 0:
            logger.critical(f"IDS:{ids_name} Time vector is empty! aborting")
            sys.exit(1)
        return ids_slice

    # Same loading order as before: transport, profiles, equilibrium.
    ids_slice_core_transport = _load_slice("core_transport")
    ids_slice_core_profiles = _load_slice("core_profiles")
    ids_slice_equilibrium = _load_slice("equilibrium")

    time_value = ids_slice_core_transport.time[-1]

    # Locate the model whose identifier name is "transport_solver".
    model_names = [model.identifier.name for model in ids_slice_core_transport.model]
    try:
        model_index = model_names.index("transport_solver")
    except ValueError as e:
        # list.index raises ValueError when the name is absent; this is a failure,
        # so report a non-zero exit status (the previous bare exit() returned 0).
        logger.critical(f'"transport_solver" not found in the transport models {e}')
        sys.exit(1)
    logger.info(f'Found model "transport_solver" as model {model_index}')

    # define columns and rows for subplots based on ions:
    # one column per ion, but never fewer than the two columns used by the electron row
    ions = ids_slice_core_transport.model[model_index].profiles_1d[-1].ion
    columns = max(2, len(ions))

    canvas = PlotCanvas(3, columns)
    canvas.update_style(args.rc)
    # Row 0: electrons (particle flux, energy flux); row 1: ion particle fluxes; row 2: ion energy fluxes.
    electrons_particle_fluxes_axes = canvas.add_axes(title="", xlabel="", row=0, col=0, colspan=1)
    electrons_energy_fluxes_axes = canvas.add_axes(title="", xlabel="", row=0, col=1, colspan=1)

    ions_particle_fluxes_axes = [canvas.add_axes(title="", xlabel="", row=1, col=x, colspan=1) for x in range(columns)]
    ions_energy_fluxes_axes = [canvas.add_axes(title="", xlabel="", row=2, col=x, colspan=1) for x in range(columns)]
    ct_view = CoreTransportView(ids_slice_core_transport)

    # NOTE(review): the literal -1 passed below is presumably the time-slice index
    # (last slice), matching the profiles_1d[-1] access above - confirm against CoreTransportView.
    ct_view.view_ions_particle_fluxes(
        ions_particle_fluxes_axes,
        ids_slice_core_transport,
        ids_slice_core_profiles,
        ids_slice_equilibrium,
        -1,
        model_index,
        logscale=args.logscale,
    )
    ct_view.view_ions_energy_fluxes(
        ions_energy_fluxes_axes,
        ids_slice_core_transport,
        ids_slice_core_profiles,
        ids_slice_equilibrium,
        -1,
        model_index,
        logscale=args.logscale,
    )
    # Unlike the other three views, this call takes no equilibrium slice.
    ct_view.view_particle_fluxes_for_electrons(
        electrons_particle_fluxes_axes,
        ids_slice_core_transport,
        ids_slice_core_profiles,
        -1,
        model_index,
        logscale=args.logscale,
    )
    ct_view.view_energy_fluxes_for_electrons(
        electrons_energy_fluxes_axes,
        ids_slice_core_transport,
        ids_slice_core_profiles,
        ids_slice_equilibrium,
        -1,
        model_index,
        logscale=args.logscale,
    )

    script_name = os.path.basename(__file__)
    canvas.set_text(text=f"{get_database_path(args, time_value=time_value)}")
    canvas.fig.suptitle(get_title(args, "Core transport", time_value))
    canvas.fig.subplots_adjust(top=0.9, bottom=0.094, left=0.035, right=0.948, hspace=0.417, wspace=0.117)
    canvas.get_current_fig_manager().set_window_title(script_name)
    canvas.remove_empty_axes()
    if args.save:
        fname = get_file_name(args, script_name + "_core_transport", time_value)
        if args.directory:
            # exist_ok avoids the check-then-create race of a separate os.path.exists test.
            os.makedirs(args.directory, exist_ok=True)
            fname = os.path.join(args.directory, fname)
        canvas.save(fname)
    else:
        canvas.show()
