#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os

import numpy as np
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
from rich_argparse import RichHelpFormatter

from idstools.database import DBMaster
from idstools.utils.idshelper import get_available_ids_and_times
from idstools.utils.idslogger import setup_logger

# Module-level logger created through the idstools logging helper.
logger = setup_logger("module")

# One indentation step (three spaces) of the tree-style listings.
TAB = " " * 3
# Minimum field widths (in characters) used when right-aligning values.
SHOT_STR_LEN = 6  # pulse numbers
RUN_STR_LEN = 5  # run numbers
DATABASE_STR_LEN = 12  # database names
VERSION_STR_LEN = 6  # data versions
IDSNAME_STR_LEN = 26  # IDS names
SLICENUM_STR_LEN = 4  # number of time slices of an IDS
RUNNUM_STR_LEN = 4  # number of runs of a pulse (compact listing)
TIME_STR_LEN = 4  # individual time values


def extended(str, wlen):
    """Return ``str`` right-aligned in a field at least ``wlen`` characters wide.

    Text that is already longer than ``wlen`` is returned unchanged; nothing
    is truncated.
    """
    # NOTE: the parameter name shadows the builtin ``str``; it is kept so that
    # the call signature stays unchanged.
    return "{:>{}s}".format(str, wlen)


def print_vector(vec, wlen):
    """Return the items of ``vec`` as one space-separated line of text.

    Every item is converted with ``str()`` and right-aligned in a field at
    least ``wlen`` characters wide. Despite its name the function prints
    nothing; it only builds and returns the string.
    """
    spec = f">{wlen}s"
    return " ".join(format(str(item), spec) for item in vec)


def print_list(dbs, pulsenum=None, compact=False, timestamp=False):
    """Print the databases as an indented tree: database, data version, backend, pulse.

    The header of a database, data version or backend is printed lazily, i.e.
    only once a pulse below it passes the ``pulsenum`` filter.

    Args:
        dbs: Iterable of ``(database name, data versions)`` pairs. Each
            ``data versions`` item is an iterable of ``(data version,
            backends)`` pairs, each ``backends`` item an iterable of
            ``(backend name, pulses)`` pairs, and ``pulses`` maps a pulse
            number to a sequence of run records. Element 1 of a run record is
            printed as the run number, element 7 as the parenthesised text in
            timestamp mode.
        pulsenum: When given, only this pulse number is listed.
        compact: Print only the number of runs of each pulse.
        timestamp: Print one line per run together with element 7 of the run
            record. Ignored when ``compact`` is set.
    """
    for dbname, dvs in dbs:
        printed_database = False
        for dv, dbbackends in dvs:
            printed_dataversion = False
            # A distinct name is used for the pulse mapping so that the ``dbs``
            # argument is not rebound while it is being iterated.
            for backend, pulses in dbbackends:
                printed_backend = False
                for pulse, runs in sorted(pulses.items()):
                    if pulsenum is not None and pulse != pulsenum:
                        continue

                    if not printed_database:
                        print(f"Database: {dbname}")
                        printed_database = True
                    if not printed_dataversion:
                        print(f"{TAB}Data version: {dv}")
                        printed_dataversion = True
                    if not printed_backend:
                        print(TAB * 2 + "Backend: " + backend)
                        printed_backend = True

                    pulse_label = TAB * 3 + "Pulse " + extended(str(pulse), SHOT_STR_LEN)
                    if compact:
                        # Only the count is needed, so the runs are not sorted here.
                        print(pulse_label + ": " + extended(str(len(runs)), RUNNUM_STR_LEN) + " runs")
                    elif timestamp:
                        print(pulse_label)
                        run_prefix = TAB * 3 + " " * (5 + SHOT_STR_LEN) + " Runs: "
                        # sorted() leaves the caller's run list untouched.
                        # NOTE(review): the -M help text mentions sorting per
                        # modification date, yet the order is by descending
                        # run number (element 1) - confirm which is intended.
                        for r in sorted(runs, key=lambda x: x[1], reverse=True):
                            print(run_prefix + extended(str(r[1]), RUN_STR_LEN) + " (" + str(r[7]) + ")")
                    else:
                        run_numbers = sorted(r[1] for r in runs)
                        print(pulse_label + " Runs: " + print_vector(run_numbers, RUN_STR_LEN))


def print_list_from_folder(dbs, pulsenum=None, compact=False, timestamp=False):
    """Print the data entries found in a folder, grouped per backend, as URIs.

    Args:
        dbs: Iterable of ``(backend name, pulses)`` pairs. ``pulses`` maps a
            file path to a sequence of detail records; element 0 of the first
            record is compared with ``pulsenum`` and element 4 is printed in
            timestamp mode.
        pulsenum: When given, only entries whose first detail record starts
            with this pulse number are listed.
        compact: Accepted for signature parity with ``print_list``; unused.
        timestamp: Additionally print element 4 of the first detail record
            below each URI.
    """
    for backend, pulses in dbs:
        printed_backend = False
        for pulse, details in pulses.items():
            # The URI points at the directory that contains the entry file.
            file_path = pulse.rsplit("/", 1)[0]
            uripath = f'"imas:{backend.lower()}?path={file_path}"'

            if pulsenum is not None and details[0][0] != pulsenum:
                continue

            # Print the header only once an entry really is listed, so that a
            # backend without matching entries does not leave a dangling header.
            if not printed_backend:
                print(TAB * 2 + "Backend: " + backend)
                printed_backend = True

            print(TAB * 3 + extended(uripath, SHOT_STR_LEN))
            if timestamp:
                print(TAB * 3 + " " * (5 + SHOT_STR_LEN) + " (" + str(details[0][4]) + ")")


def print_urilist(dbs, pulsenum=None, compact=False, timestamp=False, user=None):
    """Print a rich table with one IMAS URI per run.

    Args:
        dbs: Same nested structure as consumed by ``print_list``: iterable of
            ``(database name, data versions)`` pairs down to a mapping from
            pulse number to run records (element 1 is the run number,
            element 7 the value of the timestamp column).
        pulsenum: When given, only runs of this pulse number are listed.
        compact: Accepted for signature parity with ``print_list``; unused.
        timestamp: Add a ``timestamp`` column.
        user: User name written into the URIs. When omitted, the module-level
            ``args.user`` is used, which only exists when this file is run as
            a script.
    """
    columns = [("Sr. No.", "green"), ("uri", "yellow")]
    if timestamp:
        columns.append(("timestamp", "red"))
    table = Table()
    for column_name, colour in columns:
        table.add_column(column_name, justify="left", style=colour)

    counter = 0
    for dbname, dvs in dbs:
        for dv, dbbackends in dvs:
            # ``pulses`` instead of ``dbs``: do not rebind the argument.
            for backend, pulses in dbbackends:
                for pulse, runs in sorted(pulses.items()):
                    if pulsenum is not None and pulse != pulsenum:
                        continue

                    for r in sorted(runs, key=lambda x: x[1]):
                        # Resolved lazily so the global is only touched when a
                        # row is actually produced (backward compatible).
                        owner = user if user is not None else args.user
                        row = [
                            str(counter),
                            f'"imas:{backend.lower()}?user={owner};pulse={pulse};'
                            f'run={r[1]};database={dbname};version={dv}"',
                        ]
                        counter += 1
                        if timestamp:
                            row.append(str(r[7]))
                        table.add_row(*row)
    Console().print(table)


def print_urilist_from_folder(dbs, pulsenum=None, compact=False, timestamp=False):
    """Print a rich table with one IMAS URI per data entry found in a folder.

    Args:
        dbs: Iterable of ``(backend name, pulses)`` pairs, where ``pulses``
            maps a file path to a sequence of detail records.
        pulsenum: When given, only entries with this pulse number are listed.
        compact: Accepted for signature parity with ``print_list``; unused.
        timestamp: Add a ``timestamp`` column.
    """
    columns = [("Sr. No.", "green"), ("uri", "yellow")]
    if timestamp:
        columns.append(("timestamp", "red"))
    table = Table()
    for column_name, colour in columns:
        table.add_column(column_name, justify="left", style=colour)

    counter = 0
    for backend, pulses in dbs:
        for pulse, details in pulses.items():
            # The pulse number and timestamp are read from the first detail
            # record (elements 0 and 4), exactly as print_list_from_folder and
            # print_times_with_folder do. Comparing the whole record with
            # ``pulsenum`` could never match, which left the table empty.
            # NOTE(review): record layout inferred from those sibling
            # functions - confirm against DBMaster.get_database_files_from_folder.
            if pulsenum is not None and details[0][0] != pulsenum:
                continue

            # The URI points at the directory that contains the entry file.
            file_path = pulse.rsplit("/", 1)[0]
            row = [str(counter), f'"imas:{backend.lower()}?path={file_path}"']
            counter += 1
            if timestamp:
                row.append(str(details[0][4]))
            table.add_row(*row)
    Console().print(table)


def print_times(dbs, args, print_times=False, pulse_number=None, run_number=None, showuri=False):
    """Open every selected run and print the time information of its IDSs.

    Headers (database, data version, backend, pulse) are printed lazily, only
    once a run below them passes the filters. For each run a connection is
    obtained from ``DBMaster.get_connection`` and closed again afterwards,
    also when listing its content raises.

    Args:
        dbs: Same nested structure as consumed by ``print_list``; element 1 of
            a run record is the run number.
        args: Namespace providing ``user``, which is used for the connection.
        print_times: Print every time point instead of only first and last.
            (The name shadows this function inside its own body; it is kept
            because callers pass it by keyword.)
        pulse_number: When given, only this pulse is shown.
        run_number: When given, only this run is shown.
        showuri: Print the URI of each run instead of its run number.
    """
    for dbname, dvs in dbs:
        printed_database = False
        for dv, dbbackends in dvs:
            printed_dataversion = False
            # ``pulses`` instead of ``dbs``: do not rebind the argument.
            for backend, pulses in dbbackends:
                printed_backend = False
                for pulse, runs in sorted(pulses.items()):
                    if pulse_number is not None and pulse != pulse_number:
                        continue

                    printed_pulse = False
                    for run in sorted(r[1] for r in runs):
                        if run_number is not None and run != run_number:
                            continue

                        if not printed_database:
                            print(f"Database: {dbname}")
                            printed_database = True
                        if not printed_dataversion:
                            print(f"{TAB}Data version: {dv}")
                            printed_dataversion = True
                        if not printed_backend:
                            print(TAB * 2 + "Backend: " + backend)
                            printed_backend = True
                        if not printed_pulse:
                            print(TAB * 3 + "Pulse " + extended(str(pulse), SHOT_STR_LEN))
                            printed_pulse = True

                        uri = (
                            f"imas:{backend.lower()}?user={args.user};pulse={pulse};"
                            f"run={run};database={dbname};version={dv}"
                        )
                        connargs = argparse.Namespace(
                            backend=backend.upper(),
                            pulse=pulse,
                            run=run,
                            user=args.user,
                            database=dbname,
                            version=dv,
                            uri=uri,
                        )
                        connection = DBMaster.get_connection(connargs)
                        if showuri:
                            print(TAB * 4 + '"' + uri + '"')
                        else:
                            print(TAB * 4 + " Run: " + extended(str(run), RUN_STR_LEN))
                        if connection is None:
                            # Same handling as print_times_with_folder: an
                            # entry that cannot be opened is skipped.
                            continue

                        try:
                            for idsname, times in get_available_ids_and_times(connection):
                                if times is None:
                                    continue
                                if len(times) == 1 and np.isnan(times[0]):
                                    count = "?"
                                    detail = " slices ( heterogeneous IDS )"
                                elif len(times) == 1 and times[0] == -np.inf:
                                    count = "?"
                                    detail = " slices ( time independent IDS )"
                                elif print_times:
                                    count = str(len(times))
                                    detail = " slices (" + print_vector(times, TIME_STR_LEN) + ")"
                                else:
                                    count = str(len(times))
                                    # An empty time vector has no first/last
                                    # element; print an empty range instead of
                                    # failing with IndexError.
                                    has_times = len(times) > 0
                                    first = extended(str(times[0]), TIME_STR_LEN) if has_times else ""
                                    last = extended(str(times[-1]), TIME_STR_LEN) if has_times else ""
                                    detail = " slices (" + first + " - " + last + ")"
                                print(
                                    TAB * 5
                                    + extended(idsname, IDSNAME_STR_LEN)
                                    + ": "
                                    + extended(count, SLICENUM_STR_LEN)
                                    + detail
                                )
                            print("\n")
                        finally:
                            connection.close()


def print_times_with_folder(
    dbs, print_times=False, pulse_number=None, run_number=None, showuri=False, cli_args=None
):
    """Open every selected data entry of a folder and print its IDS time information.

    Args:
        dbs: Iterable of ``(backend name, pulses)`` pairs. ``pulses`` maps a
            file path to a sequence of detail records; elements 0 and 1 of the
            first record are used as pulse number and run number.
        print_times: Print every time point instead of only first and last.
        pulse_number: When given, only entries with this pulse number are shown.
        run_number: When given, only entries with this run number are shown.
        showuri: Additionally print the URI of each entry.
        cli_args: Namespace providing ``user``, ``database`` and ``version``
            for the connection. When omitted, the module-level ``args`` is
            used, which only exists when this file is run as a script.
    """
    for backend, pulses in dbs:
        printed_backend = False
        for pulse, details in pulses.items():
            # The URI points at the directory that contains the entry file.
            file_path = pulse.rsplit("/", 1)[0]
            uripath = f"imas:{backend.lower()}?path={file_path}"
            if pulse_number is not None and details[0][0] != pulse_number:
                continue
            if run_number is not None and details[0][1] != run_number:
                continue
            if not printed_backend:
                print(TAB * 2 + "Backend: " + backend)
                printed_backend = True
            print(TAB * 3 + "Pulse " + extended(str(pulse), SHOT_STR_LEN))

            # Resolved lazily so the global is only touched when an entry is
            # actually opened (backward compatible with the old behaviour).
            settings = cli_args if cli_args is not None else args
            connargs = argparse.Namespace(
                backend=backend.upper(),
                pulse=pulse,
                run=details[0][1],
                user=settings.user,
                database=settings.database,
                version=settings.version,
                uri=uripath,
            )
            connection = DBMaster.get_connection(connargs)
            if showuri:
                print(TAB * 4 + '"' + connargs.uri + '"')
            if connection is None:
                continue

            # try/finally: the connection is released even if listing fails.
            try:
                for idsname, times in get_available_ids_and_times(connection):
                    if times is None:
                        continue
                    if len(times) == 1 and np.isnan(times[0]):
                        count = "?"
                        detail = " slices ( heterogeneous IDS )"
                    elif len(times) == 1 and times[0] == -np.inf:
                        count = "?"
                        detail = " slices ( time independent IDS )"
                    elif print_times:
                        count = str(len(times))
                        detail = " slices (" + print_vector(times, TIME_STR_LEN) + ")"
                    else:
                        count = str(len(times))
                        # An empty time vector is shown as an empty range.
                        has_times = len(times) > 0
                        first = extended(str(times[0]), TIME_STR_LEN) if has_times else ""
                        last = extended(str(times[-1]), TIME_STR_LEN) if has_times else ""
                        detail = " slices (" + first + " - " + last + ")"
                    print(
                        TAB * 15
                        + extended(idsname, IDSNAME_STR_LEN)
                        + ": "
                        + extended(count, SLICENUM_STR_LEN)
                        + detail
                    )
                print("\n")
            finally:
                connection.close()


# Help text of the command line interface. It is wrapped in a rich ``Markdown``
# object and passed as ``description`` to the ArgumentParser built below.
description = """
This program lists existing IMAS databases.

Possible commands are:

- ``list <pulse number>``: list existing databases
- ``slices <pulse number> <run number>``: list existing databases, including number of timeslices and time
range for time-dependent IDSs
- ``times <pulse number> <run number>``: list existing databases, including number of timeslices their
time points for time-dependent IDSs
- ``databases``: list existing databases (with data versions)
- ``dataversions``: list existing dataversions (with databases)

If the optional arguments pulse number and run number are given,
only databases with these numbers will be shown. If no command is given,
the list command is performed. To see databases stored in the public imas database, use 'public' as the user name.

[Previously known as imasdbs]"""
if __name__ == "__main__":
    # ---- command line definition -------------------------------------------
    parser = argparse.ArgumentParser(
        prog="dblist",
        formatter_class=RichHelpFormatter,
        description=Markdown(description, style="argparse.text"),
    )
    subparsers = parser.add_subparsers(help="sub-commands help")

    # Every sub-parser stores its own name in ``cmd``; the attribute is absent
    # when the program is started without a sub-command.
    subparser_list = subparsers.add_parser("list", help="list databases")
    subparser_list.set_defaults(cmd="list")
    subparser_list.add_argument(
        "-c",
        "--compact",
        action="store_true",
        dest="compact",
        default=False,
        help="Compact/reduced output",
    )
    subparser_list.add_argument(
        "-M",
        "--lastModifiedDate",
        action="store_true",
        dest="timestamp",
        default=False,
        help="Show (and sort per) date of last modification of the runs",
    )
    subparser_list.add_argument("pulse", nargs="?", help="Pulse number", type=int)

    subparser_slices = subparsers.add_parser("slices", help="list slices")
    subparser_slices.set_defaults(cmd="slices")
    subparser_slices.add_argument("pulse", nargs="?", help="Pulse number", type=int)
    subparser_slices.add_argument("run", nargs="?", help="Run number", type=int)

    subparser_times = subparsers.add_parser("times", help="list times")
    subparser_times.set_defaults(cmd="times")
    subparser_times.add_argument("pulse", nargs="?", help="Pulse number", type=int)
    subparser_times.add_argument("run", nargs="?", help="Run number", type=int)

    subparser_databases = subparsers.add_parser("databases", help="print databases")
    subparser_databases.set_defaults(cmd="databases")

    subparser_data_versions = subparsers.add_parser("dataversions", help="print data versions")
    subparser_data_versions.set_defaults(cmd="dataversions")

    parser.add_argument(
        "-f",
        "--folder",
        dest="folder",
        default=None,
        help="Show data entries from specified folder (This options is useful to search data entries"
        " from folder recursively and it doesn't need structured layout to search"
        " data entries)\t\t(default=%(default)s)",
    )
    parser.add_argument(
        "-u",
        "--user",
        dest="user",
        default="public",  # os.getenv("USER"),
        help="Show databases of specified user \t\t(default=%(default)s)",
    )
    parser.add_argument(
        "-d",
        "--database",
        dest="database",
        default=None,
        help="Show only databases with specified name \t(default=%(default)s)",
    )
    parser.add_argument(
        "-v",
        "--version",
        dest="version",
        default=None,
        help="Show only databases for specified major data version \t(default=%(default)s)",
    )
    parser.add_argument(
        "--backend",
        dest="backend",
        default=None,
        help="Show databases written with given backend(s). \n\
    Comma-separated list of backends (Currently supported: mdsplus, hdf5). \
    By default all backends are shown. \t(default=%(default)s)",
    )
    parser.add_argument(
        "-showuri",
        "--showuri",
        action="store_true",
        help="Show uri",
    )
    parser.add_argument("positionalArgs", nargs="?", default=os.getcwd())

    # ---- dispatch ------------------------------------------------------------
    # NOTE: ``args`` must remain a module-level name; print_urilist and
    # print_times_with_folder read it as a global. Do not move this section
    # into a function without passing the namespace on explicitly.
    args = parser.parse_args()
    cmd = getattr(args, "cmd", None)
    backends = args.backend.split(",") if args.backend else None
    use_folder = args.folder is not None

    # The listing commands (and the default command) share one database lookup.
    if cmd in (None, "list", "slices", "times"):
        if use_folder:
            dbs = DBMaster.get_database_files_from_folder(args.folder, backends)
        else:
            dbs = DBMaster.get_database_files(args.user, args.database, args.version, backends)

    if cmd is None:
        # Default command if not provided: plain, unfiltered listing.
        if use_folder:
            print_list_from_folder(dbs)
        else:
            print_list(dbs)
        # ``exit()`` is injected by the ``site`` module and is not guaranteed
        # to exist; raising SystemExit directly yields the same exit status.
        raise SystemExit(0)

    if cmd == "list":
        if args.showuri:
            list_printer = print_urilist_from_folder if use_folder else print_urilist
        else:
            list_printer = print_list_from_folder if use_folder else print_list
        list_printer(dbs, args.pulse, args.compact, args.timestamp)

    elif cmd in ("slices", "times"):
        # "times" prints every time point, "slices" only the covered range.
        show_all_times = cmd == "times"
        if use_folder:
            print_times_with_folder(
                dbs,
                print_times=show_all_times,
                pulse_number=args.pulse,
                run_number=args.run,
                showuri=args.showuri,
            )
        else:
            print_times(
                dbs,
                args,
                print_times=show_all_times,
                pulse_number=args.pulse,
                run_number=args.run,
                showuri=args.showuri,
            )

    elif cmd == "databases":
        databases_with_versions = DBMaster.get_databases_with_versions(args.user)
        if args.version is not None:
            # Keep only the databases that provide the requested data version.
            databases_with_versions = [
                entry for entry in databases_with_versions if args.version in entry[1]
            ]
        for database_name, versions in databases_with_versions:
            print(extended(database_name, DATABASE_STR_LEN) + " " + print_vector(versions, VERSION_STR_LEN))

    elif cmd == "dataversions":
        for version, database_names in DBMaster.get_versions_with_databases(args.user):
            print(f"{extended(version, VERSION_STR_LEN)} {print_vector(database_names, DATABASE_STR_LEN)}")