#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import glob
import os
import re
import sys
from operator import itemgetter
from stat import *

import yaml

try:
    from yaml import CLoader as yamlLoader
except ImportError:
    from yaml import Loader as yamlLoader
# ----------------------------------------------------------------------------


# Help formatter that preserves the explicit line breaks of help texts prefixed with "R|"
class SmartFormatter(argparse.HelpFormatter):
    """Argparse help formatter with an opt-in "raw" mode for help strings.

    A help string that starts with the marker ``R|`` is printed as written:
    the marker is removed and the text is split only at its own newlines.
    Any other help string is wrapped by the standard argparse logic.
    """

    def _split_lines(self, text, width):
        if not text.startswith("R|"):
            # Regular help text: let argparse re-flow it to *width*.
            return super()._split_lines(text, width)
        # Raw help text: keep the author's line breaks, ignore *width*.
        raw_text = text[2:]
        return raw_text.splitlines()


# ----------------------------------------------------------------------------


# Header label of every column that can be requested through ``choice``.
_DISPLAY_NAMES = {
    "ref_name": "Reference",
    "ro_name": "RO(s)",
    "shot": "Shot",
    "run": "Run",
    "type": "Type",
    "dis_type": "Type",
    "VD_dir": "VD",
    "HF": "HF",
    "workflow": "Workflow",
    "database": "Database",
    "confinement": "Confinement",
    "ip": "Ip[MA]",
    "IREmax": "I_RE_max[MA]",
    "b0": "B0[T]",
    "ne0": "ne0[m-3]",
    "idslist": "List of IDSs",
    "tsteps": "#time steps",
    "dd_version": "DD Version",
}

# Column -> key read from the "characteristics" section of a scenario file.
_CHARACTERISTICS_KEYS = {
    "shot": "shot",
    "run": "run",
    "type": "type",
    "workflow": "workflow",
    "database": "machine",
}

# Column -> key read from the "scenario_key_parameters" section of a scenario file.
_KEY_PARAMETER_KEYS = {
    "confinement": "confinement_regime",
    "ip": "plasma_current",
    "IREmax": "I_RE_max",
    "b0": "magnetic_field",
    "ne0": "central_electron_density",
    "dis_type": "disruption_type",
    "VD_dir": "VD_direction",
    "HF": "halo_fraction",
}

# Columns shown by rec() when the caller does not choose any.
_DEFAULT_CHOICE = ("shot", "run", "dd_version", "ip", "b0")


def _check_choice(choice):
    """Abort the script (exit status 1) if *choice* is empty or names an unknown column."""
    if not choice:
        print("No variable to display.", file=sys.stderr)
        print("----> Aborted.", file=sys.stderr)
        sys.exit(1)
    for column in choice:
        if column not in _DISPLAY_NAMES:
            print("The argument " + str(column) + " is not correct.", file=sys.stderr)
            print("----> Aborted.", file=sys.stderr)
            sys.exit(1)


def _find_scenario_files(paths, extension):
    """Return every file below the directories *paths* (recursively) whose name ends in *extension*."""
    found = []
    for path in paths:
        for root, _, names in os.walk(path):
            found.extend(os.path.join(root, name) for name in names if name.endswith(extension))
    return found


def _load_scenario(file_path):
    """Parse one scenario file; return its mapping, or None (after a message on stderr) if unusable.

    The returned mapping gets two extra entries: ``location`` (the file path)
    and ``dd_version`` ("4" for paths containing ITER_DISRUPTIONS/4, else "3").
    """
    try:
        # Binary mode lets PyYAML detect the UTF-8/UTF-16 encoding of the file itself.
        # NOTE(review): yamlLoader is a full (non-safe) PyYAML loader that can build
        # arbitrary Python objects; only point this script at trusted databases.
        with open(file_path, "rb") as scenario_file:
            scenario = yaml.load(scenario_file, Loader=yamlLoader)
    except Exception as err:  # best effort: one bad file must not stop the whole listing
        print("Error reading yaml " + file_path + ": " + str(err), file=sys.stderr)
        return None

    characteristics = scenario.get("characteristics") if isinstance(scenario, dict) else None
    if not isinstance(characteristics, dict) or "shot" not in characteristics or "run" not in characteristics:
        # Entries without shot/run cannot be sorted into the table.
        print("Error reading yaml " + file_path + ": no characteristics/shot/run entry", file=sys.stderr)
        return None

    scenario["location"] = file_path
    scenario["dd_version"] = "4" if r"ITER_DISRUPTIONS/4" in file_path else "3"
    return scenario


def _scenario_cells(scenario, trunc_number):
    """Return the text displayed in each column (keyed like ``_DISPLAY_NAMES``) for one scenario."""
    characteristics = scenario.get("characteristics") or {}
    parameters = scenario.get("scenario_key_parameters") or {}
    ids = scenario.get("idslist") or {}

    # Number of time steps: taken from the summary IDS when that IDS is listed.
    tsteps = scenario.get("tsteps")
    if isinstance(ids, dict) and isinstance(ids.get("summary"), dict):
        tsteps = ids["summary"].get("time_step_number", tsteps)

    cells = {
        "ref_name": str(scenario.get("reference_name"))[:trunc_number],
        "ro_name": str(scenario.get("responsible_name")),
        # Names of the available IDSs, each followed by a space.
        "idslist": "".join(str(name) + " " for name in ids),
        "tsteps": str(tsteps),
        "dd_version": str(scenario.get("dd_version")),
    }
    for column, key in _CHARACTERISTICS_KEYS.items():
        cells[column] = str(characteristics.get(key))
    for column, key in _KEY_PARAMETER_KEYS.items():
        cells[column] = str(parameters.get(key))
    return cells


def _is_small_number(text):
    """Return True when *text* parses as a number below 10, False otherwise (including non-numbers)."""
    try:
        return float(text) < 10
    except ValueError:
        return False


def _join_columns(choice, texts, widths, small_ip=False):
    """Concatenate ``texts[column]`` for each chosen column, left-aligned on *widths*.

    Columns are separated by at least two spaces. When *small_ip* is true, the
    ip cell gets one extra leading space so its decimals line up with values
    >= 10. The column after it gets one space less, so later columns stay aligned.
    """
    line = texts[choice[0]]
    shift = 0  # extra space given to the previous cell, taken back from the next pad
    for previous, current in zip(choice, choice[1:]):
        pad = widths[previous] - len(texts[previous]) + 2 - shift
        shift = 1 if (small_ip and current == "ip") else 0
        line += " " * (pad + shift) + texts[current]
    return line


def rec(
    data=None,
    paths=None,
    extension=".yaml",
    choice=_DEFAULT_CHOICE,
    selection=None,
    obsolete_cases=False,
):
    """Build a text table summarising the disruption scenarios described by YAML files.

    Args:
        data: optional dict that receives every successfully loaded scenario,
            stored under new integer keys. Entries already present are
            tabulated too. A fresh dict is used when omitted.
        paths: directory, or list of directories, searched recursively.
            Defaults to the current working directory at call time.
        extension: suffix of the scenario files to read.
        choice: column names to display, in order (keys of ``_DISPLAY_NAMES``).
        selection: comma-separated regular expressions. When given, only the
            rows matching all of them are kept.
        obsolete_cases: also list scenarios whose status is not "active".

    Returns:
        Without *selection*: a dict ``{0: separator, 1: header, 2: separator,
        3..n+2: rows, n+3: separator}``. With *selection*: the kept lines, in
        the same order, as a list (the separators and header are always kept).

    Raises:
        SystemExit: (status 1) if *choice* is empty or contains an unknown column.
    """
    trunc_number = 70  # reference names longer than this are truncated

    if data is None:
        data = {}
    if paths is None:
        paths = [os.getcwd()]
    elif isinstance(paths, (str, os.PathLike)):
        paths = [paths]  # a bare string would otherwise be walked character by character

    # Fail fast, before walking the file system.
    _check_choice(choice)

    # Read every scenario file into *data* without overwriting existing entries.
    next_key = max((key for key in data if isinstance(key, int)), default=-1) + 1
    for file_path in _find_scenario_files(paths, extension):
        scenario = _load_scenario(file_path)
        if scenario is not None:
            data[next_key] = scenario
            next_key += 1

    # Sort by shot, then run (a tuple key keeps runs >= 1000 in order).
    scenarios = sorted(
        data.values(),
        key=lambda s: (s["characteristics"]["shot"], s["characteristics"]["run"]),
    )
    shown = [s for s in scenarios if obsolete_cases or s.get("status") == "active"]
    rows = [_scenario_cells(s, trunc_number) for s in shown]

    # Column width = longest header or displayed cell of that column.
    widths = {
        column: max([len(_DISPLAY_NAMES[column])] + [len(cells[column]) for cells in rows])
        for column in choice
    }

    header = _join_columns(choice, _DISPLAY_NAMES, widths)
    body = [_join_columns(choice, cells, widths, _is_small_number(cells["ip"])) for cells in rows]
    separator = "-" * (max([len(header)] + [len(line) for line in body]) + 1)

    # Layout: separator, header, separator, rows..., separator.
    table = {0: separator, 1: header, 2: separator}
    for index, line in enumerate(body, start=3):
        table[index] = line
    last = len(table)
    table[last] = separator

    if selection is None:
        return table

    # Keep the frame lines, plus every row in which each pattern matches somewhere.
    patterns = selection.split(",")
    frame = {0, 1, 2, last}
    return [
        line
        for index, line in table.items()
        if index in frame or all(re.search(pattern, line) for pattern in patterns)
    ]


# ----------------------------------------------------------------------

if __name__ == "__main__":

    # Columns displayed when -c/--choice is not given.
    default_choice = [
        "shot",
        "run",
        "dd_version",
        "ip",
        "b0",
        "ne0",
        "dis_type",
        "VD_dir",
        "IREmax",
        "HF",
        "workflow",
        "ref_name",
    ]

    # Input arguments
    parser = argparse.ArgumentParser(
        description="---- Script to list available disruptions in a specific folder ----",
        formatter_class=SmartFormatter,
    )
    parser.add_argument(
        "-f",
        "--folder",
        help="folder where to search for disruptions (recursive)",
        required=False,
    )
    parser.add_argument(
        "-s",
        "--selection",
        help="R|list of fields to filter: e.g. MD,up,2.65\n" "----> Select only disruptions filling these criteria",
        required=False,
    )
    parser.add_argument(
        "-o",
        "--obsolete",
        help="Show also obsolete cases",
        required=False,
        action="store_true",
    )
    parser.add_argument(
        "-c",
        "--choice",
        help="R|list of variables to display, e.g.: shot,run,ip,b0\n"
        "... available among following variables:\n"
        "        ref_name    = dataset reference name\n"
        "        ro_name     = responsible officer name\n"
        "        shot        = shot number\n"
        "        run         = run number\n"
        "        type        = data type (experimental,predictive,interpretative)\n"
        "        dis_type    = which type of disruption (MD, VDE...) \n"
        "        VD_dir      = direction of vertical displacement (up, down, central) \n"
        "        HF          = poloidal halo current fraction (HF=Ipol,halo/Ip)       \n"
        "        workflow    = suite of codes used to compute these data\n"
        "        database    = database name\n"
        # "confinement" (confinement regime, L or H) is also accepted by rec()
        # but is not listed in this help text.
        "        ip          = plasma current\n"
        "        IREmax      = maximum RE current\n"
        "        b0          = central magnetic field \n"
        "        ne0         = central electron density\n"
        "        idslist     = List of IDSs available in the data-entry\n"
        "        tsteps      = Number of time steps in the disruptions\n"
        "        dd_version  = Data dictionary version\n",
        required=False,
    )

    args = vars(parser.parse_args())

    obsolete_cases = args["obsolete"]

    # Folder paths
    if args["folder"] is not None:
        paths = [args["folder"]]
    else:
        paths = [
            "/work/imas/shared/imasdb/ITER_DISRUPTIONS/3/0",
            "/work/imas/shared/imasdb/ITER_DISRUPTIONS/4",
        ]

    # Filter (None when not given)
    selection = args["selection"]

    # Choice of variables to display: a plain comma-separated list.
    # It is parsed with str.split rather than eval, so user input is never executed.
    if args["choice"] is not None:
        choice = [name.strip() for name in args["choice"].split(",")]
        if not all(choice):
            print(
                "------------------------------------------------------------------",
                file=sys.stderr,
            )
            print("Wrong way to write your list of variables to display", file=sys.stderr)
            print("You wrote " + str(args["choice"]), file=sys.stderr)
            print(
                "Try e.g. this way: " + """ disruption_summary -c shot,run,ip     """,
                file=sys.stderr,
            )
            print(
                "------------------------------------------------------------------",
                file=sys.stderr,
            )
            sys.exit(1)
    else:
        choice = list(default_choice)
        # Print the arguments of the default call to the function
        print("""----> Default call equivalent to: """)
        print("      disruption_summary -c " + ",".join(choice))

    # Default extension and container for the loaded scenarios
    data = {}
    ext = ".yaml"

    # Call the function
    table = rec(
        data=data,
        paths=paths,
        extension=ext,
        choice=choice,
        selection=selection,
        obsolete_cases=obsolete_cases,
    )

    # Display the clean formatted table. Lines are read by index because rec()
    # may return a dict keyed by line number (not inserted in display order) or a list.
    for i in range(len(table)):
        print(table[i])
    print(" NOTE: Default read entry from Disruption database using user = 'public', database = 'ITER_DISRUPTIONS'")
