#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Extract values of IDS quantities from all data entries of a database

import argparse
import sys
from pathlib import Path

import numpy as np

try:
    import imaspy as imas
except ImportError:
    import imas
import operator

from rich.console import Console
from rich.table import Table
from rich.text import Text
from rich_argparse import RichHelpFormatter

from idstools.database import DBMaster
from idstools.utils.clihelper import get_backend_id, imas_parser
from idstools.utils.idshelper import get_quantities_from_pulses

def build_parser():
    """Build the command-line parser for the database scraper.

    The IMAS connection options read later (``backend``, ``user``,
    ``database``, ``version``) are not declared here. They are expected to
    come from the shared ``imas_parser`` parent.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="Extracts given quantities from all data entries of "
        "a given database [Previously known as idstools/db_extractor.py]. "
        "Example: dbscraper 'core_profiles/profiles_1d(0)/electrons/temperature' "
        "--list-count 5 --query 'x1[0] > 20000'",
        parents=[imas_parser],
        formatter_class=RichHelpFormatter,
    )
    parser.add_argument(
        "idspath",
        type=str,
        nargs="+",
        help="One or more IDS paths (starting with IDS name) to the desired data to be collected, e.g. "
        "equilibrium/time core_profiles/profiles_1d(0)/electrons/temperature",
    )
    parser.add_argument(
        "--query",
        type=str,
        required=False,
        help="Query expression to evaluate on the extracted IDS fields. "
        "x1 refers to the first idspath in the list, x2 to the second, etc. "
        "Examples: 'x1 > 0.5', 'mean(x1) > 10000', 'x1[0] == 1.0', 'any(x1 > 5000)'",
    )
    parser.add_argument(
        "--saveas",
        type=str,
        help="File in which to store the results of this query, in csv format",
    )
    parser.add_argument(
        "--status",
        type=str,
        help="Will list only data entries with specified status (if such metadata is available)",
    )
    parser.add_argument(
        "--list-count",
        type=int,
        default=0,
        help="number of entries user needs to display",
    )
    parser.add_argument(
        "--dd-update",
        action="store_true",
        help=(
            "Convert IDS to the default version of the data dictionary if enabled; "
            "otherwise, use the original IDS stored on disk."
        ),
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Verbose mode, prints additional information",
    )
    return parser


def format_cell(value):
    """Render a single table cell value as a string.

    Numpy arrays with at most 10 elements are printed in full with 3-digit
    precision. Larger numeric arrays are summarised by shape and mean. Larger
    non-numeric arrays are summarised by shape and dtype. Any other value is
    passed through ``str()``.

    Args:
        value: the cell value taken from the results DataFrame.

    Returns:
        str: the text to display.
    """
    if not isinstance(value, np.ndarray):
        return str(value)
    if value.size <= 10:
        return np.array2string(value, precision=3, suppress_small=True)
    # A mean only makes sense for numeric dtypes (strings, objects, ... fail).
    if np.issubdtype(value.dtype, np.number):
        return f"{value.shape} array (mean={np.mean(value):.2f})"
    return f"{value.shape} array (dtype={value.dtype})"


def print_table(df, idspaths, query=None):
    """Print the extracted quantities as a rich table on stdout.

    Args:
        df: DataFrame returned by ``get_quantities_from_pulses``. It must have
            a ``"URI"`` column, one column per entry of *idspaths* and, when
            *query* is given, a column named after the query string.
        idspaths: the IDS paths requested on the command line.
        query: the query expression, or None/empty when no query was given.
    """
    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("URI", style="green", no_wrap=False)
    for idspath in idspaths:
        table.add_column(Text(str(idspath), overflow="fold"), style="yellow")
    if query:
        table.add_column(Text(str(query), overflow="fold"), style="red")

    # Scoped print options: array2string summarises arrays with more than 4
    # elements, without leaking the setting into global numpy state.
    with np.printoptions(threshold=4):
        for _, row in df.iterrows():
            cells = [Text(str(row["URI"]), overflow="fold")]
            cells.extend(format_cell(row[idspath]) for idspath in idspaths)
            if query:
                cells.append(format_cell(row[query]))
            table.add_row(*cells)
            table.add_section()

    Console().print(table)


def main():
    """Scrape the requested IDS quantities from every data entry of a database.

    The results are written as CSV to ``--saveas`` when that option is given.
    Otherwise they are printed as a rich table.

    Raises:
        FileNotFoundError: if the parent directory of ``--saveas`` does not exist.
    """
    args = build_parser().parse_args()

    # Validate the output location up front so a typo fails immediately
    # instead of after a full database scan.
    if args.saveas:
        out_dir = Path(args.saveas).parent
        if not out_dir.exists():
            raise FileNotFoundError(
                f"Output directory '{out_dir}' for --saveas does not exist. Please check spelling."
            )

    backend = get_backend_id(args.backend)
    dbmaster = DBMaster()
    if backend == imas.ids_defs.MDSPLUS_BACKEND:
        # NOTE: --status is only forwarded for the MDSplus backend.
        pulses = dbmaster.get_mds_plus_pulses(args.user, args.database, args.version, status=args.status)
    elif backend == imas.ids_defs.HDF5_BACKEND:
        pulses = dbmaster.get_hdf5_pulses(args.user, args.database, args.version)
    else:
        # Unsupported backend is an error: report on stderr, exit non-zero.
        print(f"Functionality not yet implemented for backend {args.backend}", file=sys.stderr)
        sys.exit(1)

    if args.verbose:
        import pprint

        pprint.pprint(pulses)

    if pulses is None:
        return

    df = get_quantities_from_pulses(
        args.idspath, tuple(pulses), args.list_count, args.verbose, query=args.query, dd_update=args.dd_update
    )

    if args.saveas:
        df.to_csv(args.saveas, na_rep="None", index=True, header=True)
    else:
        print_table(df, args.idspath, args.query)


if __name__ == "__main__":
    main()
