#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Evaluate performance of MDSPLUS and HDF5 Backends with IDS get operation

import argparse
import os
from datetime import datetime

try:
    import imaspy as imas
except ImportError:
    import imas
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from rich.progress import track
from rich_argparse import RichHelpFormatter

from idstools.database import DBMaster
from idstools.idsperf import byte_size, get_timings
from idstools.utils.clihelper import imas_parser
from idstools.utils.idshelper import (
    get_available_ids_and_occurrences,
)

# Column names shared by the benchmark table, the generated CSV files and the plots.
# The standard-deviation names must match the CSV header exactly (note the hyphen),
# otherwise the ``yerr=`` lookups in the plots raise KeyError.
MDS_TIME_COL = "MDSPLUS_TIME"
MDS_STD_COL = "MDSPLUS_TIME-ST. DEV."
HDF_TIME_COL = "HDF5_TIME"
HDF_STD_COL = "HDF5_TIME-ST. DEV."
NO_METRIC_MESSAGE = "Please specify performance-measuring command line arguments (size, slices, time, etc.)"


def build_parser():
    """Create the command line parser of the benchmark script."""
    parser = argparse.ArgumentParser(
        description="Benchmark backend performance by getting times of IDSes of given Pulse and Run, "
        "contrasts with IDS size and slices. replaced by db_backend_benchmark.py",
        parents=[imas_parser],
        formatter_class=RichHelpFormatter,
    )
    parser.add_argument(
        "-md",
        "--mdsdatabase",
        type=str,
        help="Path to database in MDSPLUS (if separate to HDF5)",
    )
    parser.add_argument(
        "-mdu",
        "--mdsuser",
        type=str,
        help="User harboring database in MDSPLUS (if separate to HDF5)",
    )
    parser.add_argument(
        "-hd",
        "--hdf5database",
        type=str,
        help="Path to database in HDF5 (if separate to MDSPLUS)",
    )
    parser.add_argument(
        "-hdu",
        "--hdf5user",
        type=str,
        help="User harboring database in HDF5 (if separate to MDSPLUS)",
    )
    parser.add_argument(
        "-f",
        "--file",
        type=str,
        help="Path to already generated .csv file with backend performances",
    )
    parser.add_argument("-p", "--pulse", type=int, help="Pulse number harboring desired IDSs")
    parser.add_argument("--run", type=int, help="Run number harboring desired IDSs")
    parser.add_argument("--ids", nargs="*", type=str, help="IDS to test and/or plot")
    parser.add_argument(
        "-rep",
        "--repeat",
        type=int,
        default=10,
        help="Number of times IDS get should be performed",
    )
    parser.add_argument(
        "-excp",
        "--excludep",
        nargs="*",
        type=str,
        help="Pulses to exclude: PULSE (all of its runs) or PULSE,RUN (one entry)",
    )
    parser.add_argument("-exci", "--excludeids", nargs="*", type=str, help="IDS(s) to exclude")
    parser.add_argument(
        "-rot",
        "--rotation",
        type=float,
        default=90.0,
        help="Number of degrees to tilt IDS names on x-axis (default=90)",
    )
    parser.add_argument(
        "-fw",
        "--fontweight",
        type=str,
        default="bold",
        help="Fontweight for plot axes (options include: 'ultralight', 'light', "
        "'normal', 'bold', 'heavy') (default='bold')",
    )
    parser.add_argument(
        "-siz",
        "--sizevtime",
        action="store_true",
        help="Evaluate performance based on time taken to read IDS size in MB",
    )
    parser.add_argument("-xlog", "--xlogscale", action="store_true", help="Scale x-axis with logscale")
    parser.add_argument("-ylog", "--ylogscale", action="store_true", help="Scale y-axis with logscale")
    parser.add_argument(
        "-sli",
        "--slicevtime",
        action="store_true",
        help="Evaluate performance based on time taken to read number of IDS slices",
    )
    parser.add_argument(
        "-tim",
        "--time",
        action="store_true",
        help="Evaluate performance based on time taken to read an IDS",
    )
    parser.add_argument(
        "--idsperf",
        action="store_true",
        help="Display performance of individual IDSs on all measured runs",
    )
    parser.add_argument(
        "--avgdf",
        action="store_true",
        help="Return dataframe and consequent plots of measured values' averages in evaluated pulse(s)",
    )
    parser.add_argument(
        "-verb",
        "--verbose",
        action="store_true",
        help="Verbose mode for printing used pulses and dataframes",
    )
    return parser


def parse_excluded_pulses(tokens):
    """Translate ``--excludep`` tokens into exclusion sets.

    Each token is either ``PULSE`` (exclude every run of that pulse) or
    ``PULSE,RUN`` / ``PULSE/RUN`` (exclude that single entry); surrounding
    parentheses or brackets are ignored.

    Returns:
        A pair ``(entries, pulses)``: a set of ``(pulse, run)`` tuples and a set
        of pulse numbers.

    Raises:
        ValueError: if a token has no fields, more than two fields, or a field
            that is not an integer.
    """
    excluded_entries = set()
    excluded_pulses = set()
    for token in tokens:
        raw_fields = token.strip("()[] ").replace("/", ",").split(",")
        fields = [field.strip() for field in raw_fields if field.strip()]
        if len(fields) == 1:
            excluded_pulses.add(int(fields[0]))
        elif len(fields) == 2:
            excluded_entries.add((int(fields[0]), int(fields[1])))
        else:
            raise ValueError(f"Cannot interpret excluded pulse/run {token!r}; expected PULSE or PULSE,RUN")
    return excluded_entries, excluded_pulses


def drop_excluded_pulses(pulses, tokens):
    """Return the ``(pulse, run)`` entries of *pulses* not excluded by *tokens*, in order."""
    excluded_entries, excluded_pulses = parse_excluded_pulses(tokens)
    return [
        entry
        for entry in pulses
        if (entry[0], entry[1]) not in excluded_entries and entry[0] not in excluded_pulses
    ]


def safe_rate(amount, seconds):
    """Return ``amount / seconds``, or NaN when the division cannot be performed.

    A failed slice measurement is stored as ``None``; mapping it to NaN keeps
    the remaining benchmark results usable instead of aborting before the CSV
    file is written.
    """
    try:
        return amount / seconds
    except (TypeError, ZeroDivisionError):
        return float("nan")


def add_performance_columns(table, sizevtime, slicevtime):
    """Add the derived throughput columns to *table* in place.

    Args:
        table: benchmark table with ``SIZE``, the two mean-time columns and,
            when *slicevtime* is set, a ``SLICES`` column.
        sizevtime: add the ``(MB/s)`` columns.
        slicevtime: add the ``(Slices/s)`` columns.
    """
    if sizevtime:
        table["MDSPLUS_PERFORMANCE (MB/s)"] = table["SIZE"] / table[MDS_TIME_COL]
        table["HDF5_PERFORMANCE (MB/s)"] = table["SIZE"] / table[HDF_TIME_COL]

    if slicevtime:
        table["MDSPLUS_PERFORMANCE (Slices/s)"] = [
            safe_rate(amount, seconds) for amount, seconds in zip(table["SLICES"], table[MDS_TIME_COL])
        ]
        table["HDF5_PERFORMANCE (Slices/s)"] = [
            safe_rate(amount, seconds) for amount, seconds in zip(table["SLICES"], table[HDF_TIME_COL])
        ]


def plot_backend_pair(frame, args, title, mds_col, hdf_col, ylabel, xlabel, x=None, colors=(None, None), tick_labels=None):
    """Draw the MDSPLUS and HDF5 curves of one metric on a shared axis.

    Args:
        frame: DataFrame holding *mds_col*, *hdf_col* and both deviation columns.
        args: parsed command line (repeat, fontweight, log scales, rotation).
        title: plot title.
        mds_col / hdf_col: columns plotted on the y-axis.
        ylabel / xlabel: axis labels.
        x: optional column used for the x-axis (the index is used otherwise).
        colors: optional ``(mdsplus, hdf5)`` line colours.
        tick_labels: optional labels placed at positions ``0..n-1`` on the x-axis.
    """
    # NOTE(review): the error bars are the standard deviations of the get *times*,
    # also on the throughput plots (as in the original script) - confirm intent.
    common = {"capsize": 10 if args.repeat > 1 else 0, "marker": "."}
    if x is not None:
        common["x"] = x

    mds_kwargs = dict(common, title=title, y=mds_col, yerr=MDS_STD_COL)
    hdf_kwargs = dict(common, y=hdf_col, yerr=HDF_STD_COL)
    if colors[0] is not None:
        mds_kwargs["color"] = colors[0]
    if colors[1] is not None:
        hdf_kwargs["color"] = colors[1]

    ax = frame.plot(**mds_kwargs)
    frame.plot(ax=ax, **hdf_kwargs)

    plt.ylabel(ylabel, fontweight=args.fontweight)
    plt.xlabel(xlabel, fontweight=args.fontweight)
    if args.xlogscale:
        plt.xscale("log")
    if args.ylogscale:
        plt.yscale("log")
    if tick_labels is not None:
        plt.xticks(np.arange(len(tick_labels)), tick_labels, rotation=args.rotation)


def plot_requested_metrics(frame, args, perf_title, time_title, xlabel, x=None, tick_labels=None, size_colors=(None, None)):
    """Create one figure per metric flag (``--sizevtime``, ``--slicevtime``, ``--time``)."""
    if args.sizevtime:
        plot_backend_pair(
            frame,
            args,
            perf_title,
            "MDSPLUS_PERFORMANCE (MB/s)",
            "HDF5_PERFORMANCE (MB/s)",
            "PERFORMANCE (MB/s)",
            xlabel,
            x=x,
            colors=size_colors,
            tick_labels=tick_labels,
        )

    if args.slicevtime:
        plot_backend_pair(
            frame,
            args,
            perf_title,
            "MDSPLUS_PERFORMANCE (Slices/s)",
            "HDF5_PERFORMANCE (Slices/s)",
            "PERFORMANCE (Slices/s)",
            xlabel,
            x=x,
            tick_labels=tick_labels,
        )

    if args.time:
        plot_backend_pair(
            frame,
            args,
            time_title,
            MDS_TIME_COL,
            HDF_TIME_COL,
            "TIME (s)",
            xlabel,
            x=x,
            tick_labels=tick_labels,
        )


def measure_entry(args, mdsde, hdfde, pulse, run):
    """Time the IDS ``get`` on both opened entries and return one row per IDS.

    Each row is ``((pulse, run), ids, size_MB, mds_mean, mds_std, hdf_mean,
    hdf_std)``, followed by the slice measurement when ``--slicevtime`` is set.
    """
    if args.ids:
        # --ids holds plain IDS names; pair each with occurrence 0 so the list has
        # the same (name, occurrence) shape as the discovered one.
        idslist = [(name, 0) for name in args.ids]
    else:
        idslist = get_available_ids_and_occurrences(hdfde)

    if args.excludeids:
        idslist = [i for i in idslist if i[0] not in args.excludeids]

    if args.verbose:
        print("Pulse " + str(pulse))
        print("Run " + str(run))
        print("MDSPLUS database entry opened")
        print("HDF5 database entry opened")
        print(idslist)

    rows = []
    for ids in idslist:
        print("Getting performance for ids:", ids)
        # NOTE(review): the occurrence (ids[1]) is not passed to get()/get_timings(),
        # so the default occurrence is what gets measured - confirm.
        ids2process = ids[0]
        ids_obj = hdfde.get(ids2process)
        size = byte_size(ids_obj) / 1024**2

        slice_value = None
        if args.slicevtime:
            # NOTE(review): the return value of get_timings(..., times=[10]) is stored
            # as the slice figure; its exact meaning is not visible here - confirm.
            try:
                slice_value = get_timings(hdfde, ids2process, times=[10], repeat=args.repeat)
            except Exception as e:
                print(f"{ids2process} in ({pulse} , {run}) failed to return slice get_timings - detailed:{e}")

        mdstimes = get_timings(mdsde, ids2process, repeat=args.repeat)
        mdstdev = np.std(mdstimes)
        hdftimes = get_timings(hdfde, ids2process, repeat=args.repeat)
        hdfstdev = np.std(hdftimes)
        mdsmeantime = sum(mdstimes) / len(mdstimes)
        hdfmeantime = sum(hdftimes) / len(hdftimes)

        row = ((pulse, run), ids2process, size, mdsmeantime, mdstdev, hdfmeantime, hdfstdev)
        if args.slicevtime:
            # Keep the slice value inside the row so it stays attached to its IDS
            # when the table is sorted by size.
            row += (slice_value,)
        rows.append(row)
    return rows


def benchmark_entry(args, pulse, run, mds_source, hdf_source):
    """Open one pulse/run in both backends, measure it and always close the entries.

    Args:
        args: parsed command line.
        pulse / run: entry to benchmark.
        mds_source / hdf_source: ``(database, user)`` pairs for each backend.
    """
    mdsde = imas.DBEntry(
        imas.ids_defs.MDSPLUS_BACKEND, mds_source[0], pulse, run, mds_source[1], data_version=args.version
    )
    mdsde.open()
    try:
        hdfde = imas.DBEntry(
            imas.ids_defs.HDF5_BACKEND, hdf_source[0], pulse, run, hdf_source[1], data_version=args.version
        )
        hdfde.open()
        try:
            return measure_entry(args, mdsde, hdfde, pulse, run)
        finally:
            hdfde.close()
    finally:
        mdsde.close()


def select_pulses(args, mdslocpath, hdflocpath):
    """List the ``(pulse, run)`` entries present in both backends, minus exclusions."""
    hpulses = DBMaster.hdf5_list_pulse_run(hdflocpath)
    mpulses = DBMaster.mds_list_pulse_run(mdslocpath)

    if set(hpulses) != set(mpulses):
        mdsrem = set(mpulses) - set(hpulses)
        hdfrem = set(hpulses) - set(mpulses)
        pulses = list(set(mpulses) & set(hpulses))
        print(
            "WARNING: The provided database(s) do(es) not have the same Pulses/Runs "
            "in both backends. The intersection of the two will be used."
        )
        if args.verbose:
            pad = " " * 20
            print("The differences are:\n" + pad + "MDSPLUS has: ")
            print(list(mdsrem))
            print("\n" + pad)
            print("while HDF5 has: ")
            print(list(hdfrem))
            print("\n" + pad + "The intersection of the two databases that will be used is: ")
            print(pulses)
    else:
        pulses = hpulses

    if args.excludep:
        pulses = drop_excluded_pulses(pulses, args.excludep)
    return pulses


def run_benchmark(args, parser):
    """Measure every selected entry and return ``(table, date)``.

    The table is also written to ``totaldf-<date>.csv`` in the working directory.
    """
    if (args.pulse is None) != (args.run is None):
        parser.error("--pulse and --run must be given together")

    if args.database:
        mds_source = hdf_source = (args.database, args.user)
        mdslocpath = DBMaster.get_db_Path(args.user, args.database, args.version)
        hdflocpath = mdslocpath
    elif args.mdsdatabase and args.hdf5database:
        mds_source = (args.mdsdatabase, args.mdsuser)
        hdf_source = (args.hdf5database, args.hdf5user)
        mdslocpath = DBMaster.get_db_Path(args.mdsuser, args.mdsdatabase, args.version)
        hdflocpath = DBMaster.get_db_Path(args.hdf5user, args.hdf5database, args.version)
    else:
        parser.error(
            "Please clarify where the data is stored. Please provide either one database "
            "containing data in both MDSPLUS and HDF5, or two separate databases "
            "containing data in MDSPLUS or HDF5"
        )

    if args.pulse is not None:
        pulses = [(args.pulse, args.run)]
    else:
        pulses = select_pulses(args, mdslocpath, hdflocpath)

    log = []
    for entry in pulses:
        rows = benchmark_entry(args, entry[0], entry[1], mds_source, hdf_source)
        log.extend(rows)

        if args.verbose:
            print("MDSPLUS database closed")
            print("HDF5 database closed")
            if rows:
                print(rows[-1])

    columns = ["(PULSE, RUN)", "IDS", "SIZE", MDS_TIME_COL, MDS_STD_COL, HDF_TIME_COL, HDF_STD_COL]
    if args.slicevtime:
        columns.append("SLICES")
    totaldf = pd.DataFrame(log, columns=columns).sort_values("SIZE")
    add_performance_columns(totaldf, args.sizevtime, args.slicevtime)

    date = datetime.now().strftime("%Y_%m_%d-%I:%M:%S_%p")
    totaldf.to_csv("totaldf-" + date + ".csv", na_rep="None", index=False, header=True)
    return totaldf, date


def main():
    """Benchmark the MDSPLUS and HDF5 backends (or load earlier results) and plot them."""
    parser = build_parser()
    args = parser.parse_args()

    # Reject a plot request without any metric before the (potentially long)
    # benchmark starts, rather than after it has finished.
    if (args.avgdf or args.idsperf) and not (args.sizevtime or args.slicevtime or args.time):
        parser.error(NO_METRIC_MESSAGE)

    date = None
    if args.file:
        totaldf = pd.read_csv(args.file)

        if args.ids:
            totaldf = totaldf[totaldf["IDS"].isin(args.ids)]

        if args.excludeids:
            totaldf = totaldf[~totaldf["IDS"].isin(args.excludeids)]
    else:
        totaldf, date = run_benchmark(args, parser)

    if args.verbose:
        print(totaldf)

    if args.avgdf:
        # numeric_only skips the "(PULSE, RUN)" column, which cannot be averaged.
        avgdf = totaldf.groupby("IDS").mean(numeric_only=True)

        if args.verbose:
            print(avgdf)

        plot_requested_metrics(
            avgdf,
            args,
            "Performance of all IDSs",
            "Get Times for all IDSs",
            "IDS",
            tick_labels=avgdf.index,
        )

        if not args.file:
            # The IDS names are the index of the averaged table; keep them in the file.
            home = os.path.expanduser("~" + (args.user or ""))
            avgdf.to_csv(
                home + "/public/avgdf-" + date + ".csv",
                na_rep="None",
                index=True,
                header=True,
            )

    if args.idsperf:
        evaldf = totaldf.groupby("IDS")

        if args.verbose:
            print(evaldf)

        for key, group in evaldf:
            plot_requested_metrics(
                group,
                args,
                key,
                key,
                "SIZE (MB)",
                x="SIZE",
                size_colors=("blue", "orangered"),
            )

    plt.show()


if __name__ == "__main__":
    main()
