#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os
import sys

try:
    import imaspy as imas
except ImportError:
    import imas
from itertools import tee

import numpy as np
import rich
from rich.columns import Columns
from rich.console import Console
from rich.pretty import Pretty
from rich.table import Table
from rich.tree import Tree
from rich_argparse import RichHelpFormatter

from idstools.database import DBMaster
from idstools.utils.clihelper import (
    dbentry_parser,
    get_database_path,
    get_file_name,
    get_title,
    rcparam_parser,
)
from idstools.utils.idshelper import get_available_ids_and_times, parse_uri, partial_get
from idstools.utils.idslogger import setup_logger
from idstools.view.common import PlotCanvas

# Shared logger for this script, created by the project's idslogger helper
# (stdout_level=logging.INFO presumably sets the console threshold to INFO).
logger = setup_logger("module", stdout_level=logging.INFO)


def view_plot(
    ax,
    field,
    coordinate,
    field_name="",
    coordinate_name="",
    field_unit="",
    coordinate_unit="",
):
    """Plot a numeric IDS field (or a plain numpy array) against a coordinate.

    Args:
        ax: Axes object receiving the line plot (``plot``/``set_xlabel``/``legend``).
        field: Data to plot, an ``IDSNumericArray`` or ``np.ndarray``. Anything else
            is rejected with an error log.
        coordinate: X-axis source, one of:

            * ``None`` -- for IDS fields, the first coordinate of ``field`` that has
              a value is used;
            * ``int`` -- index into ``field.coordinates`` (``0`` is a valid index);
            * ``str`` -- IDS path evaluated as ``ids.<path>`` against a module-level
              ``ids`` object;
            * an IDS numeric node or array, used as given.

            When nothing usable is found, an index array ``0..len(data)-1`` is used.
        field_name: Legend label; also the y-axis name for plain numpy input.
        coordinate_name: X-axis name when it cannot be taken from IDS metadata.
        field_unit: Y-axis unit for plain numpy input.
        coordinate_unit: X-axis unit when it cannot be taken from IDS metadata.
    """
    if not isinstance(field, (imas.ids_primitive.IDSNumericArray, np.ndarray)):
        logger.error("Not a numeric array, Please select ids path")
        return
    data = field
    yfield_name = field_name
    # A user-supplied IDS coordinate carries its own label and unit.
    if isinstance(coordinate, imas.ids_primitive.IDSNumericArray):
        coordinate_unit = f"{coordinate.metadata.units}"
        coordinate_name = coordinate.metadata.path
    if isinstance(field, imas.ids_primitive.IDSNumericArray):
        if not field.has_value:
            logger.error("Values are not present")
            return
        data = field.value
        field_unit = field.metadata.units
        yfield_name = field.metadata.name

        candidate = None
        # Check types before truthiness: bool() of an array is ambiguous, and
        # index 0 must not be mistaken for "no coordinate given".
        if isinstance(coordinate, int) and not isinstance(coordinate, bool):
            try:
                candidate = field.coordinates[coordinate]
            except IndexError:
                logger.error(f"Coordinate index {coordinate} is out of range")
        elif isinstance(coordinate, str):
            try:
                # NOTE(review): eval of a user-supplied path; needs a global ``ids``.
                candidate = eval("ids." + coordinate)
            except Exception as e:
                logger.error(f"{coordinate} path does not exist, detailed error {e}")
        elif coordinate is None or (
            isinstance(coordinate, imas.ids_primitive.IDSPrimitive) and not coordinate.has_value
        ):
            # No usable coordinate supplied: pick the first filled coordinate of the field.
            candidate = next(
                (
                    _coordinate
                    for _coordinate in field.coordinates
                    if isinstance(_coordinate, imas.ids_primitive.IDSPrimitive) and _coordinate.has_value is True
                ),
                None,
            )
        if isinstance(candidate, imas.ids_primitive.IDSPrimitive) and candidate.has_value is True:
            coordinate = candidate
            coordinate_unit = f"{candidate.metadata.units}"
            coordinate_name = candidate.metadata.name

    # Anything that is still not an array (unresolved index/path, empty node) falls
    # back to a plain index axis.
    if (
        coordinate is None
        or isinstance(coordinate, (int, str))
        or (isinstance(coordinate, imas.ids_primitive.IDSNumericArray) and coordinate.has_value is False)
    ):
        logger.error("Coordinates are empty, creating default array, you can also provide custom coordinates")
        coordinate = np.arange(len(data))
        coordinate_name = "Index"
        coordinate_unit = "-"

    if len(data) < 5 and len(data.shape) == 1:
        # Very short series get markers so individual points stay visible.
        ax.plot(coordinate, data, label=field_name, marker="o", color="red")
    elif len(data.shape) == 2:
        # One line per column; only the first carries the legend label.
        for j in range(data.shape[1]):
            if j == 0:
                ax.plot(coordinate, data[:, j], label=f"{field_name}")
            else:
                ax.plot(coordinate, data[:, j])
    else:
        ax.plot(coordinate, data, label=field_name)

    ax.set_xlabel(f"{coordinate_name} [{coordinate_unit}]", labelpad=1)
    ax.set_ylabel(f"{yfield_name} [{field_unit}]", labelpad=0)
    ax.legend()


def print_tree(structure, hide_empty_nodes, compact, full_array, depth=None):
    """Render *structure* as a rich tree on the terminal.

    Args:
        structure: IDS node or numpy array to display.
        hide_empty_nodes: Skip children without a value.
        compact: Show names only, without values.
        full_array: Print every array element instead of numpy's summary
            (summarisation kicks in above 5 elements otherwise).
        depth: Maximum depth to traverse (None for unlimited).
    """
    # Only the summarisation threshold differs between the two display modes.
    threshold = sys.maxsize if full_array else 5
    with np.printoptions(threshold=threshold, linewidth=1024, precision=4):
        # Build and print inside the context so Pretty renders with these options.
        rendered = _make_tree(structure, hide_empty_nodes, compact, depth=depth)
        rich.print(rendered)


def _leaf_columns(label, value, is_empty, compact):
    """Return the ``Columns`` row for one leaf: coloured *label* plus pretty *value*."""
    if compact:
        return Columns([f"[bright_yellow]{label}[/]"])
    if is_empty:
        # Empty leaves are greyed out and get an empty (styled) value cell.
        return Columns([f"[grey62]{label}[/]", "[bright_black]"])
    return Columns([f"[bright_yellow]{label}[/]:", Pretty(value)])


def _has_nonempty_child(structure):
    """Return True when *structure* yields at least one non-empty child node."""
    return any(True for _ in structure.iter_nonempty_(accept_lazy=True))


def _make_tree(structure, hide_empty_nodes, compact, *, tree=None, depth=None, current_depth=0):
    """Build the ``rich.tree.Tree`` for display in :py:meth:`print_tree`.

    Args:
        structure: IDS node to add to the tree: toplevel/structure/struct array,
            an ``IDSNumericArray`` or a plain ``np.ndarray``.
        hide_empty_nodes: Skip children without a value (applied when iterating
            an ``IDSStructure``).
        compact: Compact display mode: names only, no values.

    Keyword Args:
        tree: If provided, child items will be added to this Tree object. Otherwise a
            new Tree is constructed.
        depth: Maximum depth to traverse (None for unlimited).
        current_depth: Current traversal depth (internal parameter).

    Returns:
        The populated ``Tree``. Scalar inputs (``np.float64`` or a non-array
        ``IDSPrimitive``) passed without *tree* are returned unchanged.

    Note:
        Child structures with no non-empty descendants are always skipped,
        regardless of *hide_empty_nodes*.
    """
    if depth is not None and current_depth > depth:
        return tree

    if tree is None:
        if isinstance(structure, (np.ndarray, imas.ids_primitive.IDSNumericArray)):
            tree = Tree(f"[magenta]{type(structure)}")
            label = f"numpy.ndarray|{structure.shape}|{structure.dtype}"
            tree.add(_leaf_columns(label, structure, structure.size == 0, compact))
            return tree
        if isinstance(structure, (np.float64, imas.ids_primitive.IDSPrimitive)):
            # Scalars are handed back and printed directly by rich.
            return structure
        tree = Tree(f"[magenta]{structure.metadata.name}")

    children = structure
    if hide_empty_nodes and isinstance(structure, imas.ids_structure.IDSStructure):
        children = structure.iter_nonempty_(accept_lazy=True)
    for counter, child in enumerate(children):
        if isinstance(child, np.ndarray):
            label = f"numpy.ndarray|{child.shape}|{child.dtype}|[{counter}]"
            tree.add(_leaf_columns(label, child, child.size == 0, compact))
        elif isinstance(child, imas.ids_primitive.IDSPrimitive):
            has_value = child.has_value
            value = child.value if has_value else None
            tree.add(_leaf_columns(child.metadata.name, value, not has_value, compact))
        elif isinstance(child, imas.ids_structure.IDSStructure):
            if _has_nonempty_child(child):
                ntree = tree.add(f"[magenta]{child._path}[/]")
                _make_tree(child, hide_empty_nodes, compact, tree=ntree, depth=depth, current_depth=current_depth + 1)
        elif isinstance(child, imas.ids_struct_array.IDSStructArray):
            if not child.has_value:
                # "[]" is not a rich tag, so it renders literally to mark an empty array.
                tree.add(f"[magenta]{child._path}[][/]")
            # Array elements are attached directly to the current tree.
            _make_tree(child, hide_empty_nodes, compact, tree=tree, depth=depth, current_depth=current_depth + 1)

    return tree


def _make_dict_tree(structure, hide_empty_nodes, compact, *, tree=None):
    """Convert an IDS structure into nested Python dicts/lists for export.

    Args:
        structure: IDS structure (toplevel, structure or struct-array element) to convert.
        hide_empty_nodes: Skip children without a value (applied when iterating an
            ``IDSStructure``).
        compact: Unused; accepted for signature parity with :py:func:`_make_tree`.

    Keyword Args:
        tree: If provided, entries are added to this dict. Otherwise a new dict is
            constructed.

    Returns:
        The populated dict. Primitive values become plain Python values (arrays via
        ``tolist()``), sub-structures become nested dicts and struct arrays become
        lists with exactly one dict per element.
    """
    if tree is None:
        tree = {}

    children = structure
    if hide_empty_nodes and isinstance(structure, imas.ids_structure.IDSStructure):
        children = structure.iter_nonempty_(accept_lazy=True)

    for child in children:
        if isinstance(child, np.ndarray):
            # NOTE(review): replaces the whole result with the array contents; only
            # reachable for non-IDS input.
            tree = child.tolist()
        elif isinstance(child, imas.ids_primitive.IDSPrimitive):
            data = child.value
            if isinstance(data, np.ndarray):
                data = data.tolist()
            tree[child.metadata.name] = data
        elif isinstance(child, imas.ids_structure.IDSStructure):
            tree[child.metadata.name] = _make_dict_tree(child, hide_empty_nodes, compact, tree={})
        elif isinstance(child, imas.ids_struct_array.IDSStructArray):
            # One dict per array element (the previous code appended each one twice).
            tree[child.metadata.name] = [
                _make_dict_tree(element, hide_empty_nodes, compact, tree={}) for element in child
            ]
    return tree


def _build_parser():
    """Create the ``idsprint`` command line parser."""
    parser = argparse.ArgumentParser(
        description="""Prints content of an IDS onto the terminal.
        The selected IDS is given in the URI #fragment,
        e.g. idsprint -u 'imas:mdsplus?user=public;pulse=122525;run=1;database=ITER;version=3#equilibrium'
        It can also print the content of a field or substructure, by modifying the fragment, e.g.
        '#equilibrium/time_slice[0]/profiles_2d[0].r'.
        [idsprint was previously known as idsdump and idsdumppath]""",
        formatter_class=RichHelpFormatter,
        parents=[dbentry_parser, rcparam_parser],
    )
    parser.add_argument(
        "-e",
        "--show-empty",
        action="store_true",
        dest="show_empty",
        help="Show empty fields of ids",
    )
    parser.add_argument(
        "-f",
        "--full",
        action="store_true",
        help="Print all array elements (can be slow for large data)",
    )
    parser.add_argument(
        "-c",
        "--compact",
        action="store_true",
        help="Print only names which has data",
    )
    parser.add_argument(
        "-d",
        "--depth",
        type=int,
        default=None,
        help="Maximum depth to traverse when printing tree structure (None for unlimited)",
    )
    parser.add_argument(
        "-i",
        "--inspect",
        action="store_true",
        help="Print child nodes information and metadata",
    )
    parser.add_argument("-t", "--time", help="Time", required=False, type=float, default=None)
    parser.add_argument(
        "--export",
        action="store_true",
        help=("export ids data to use in other format"),
    )
    parser.add_argument(
        "--export-type",
        type=str,
        default="mat",
        help=("type of export mat, json, yaml"),
    )
    parser.add_argument(
        "-p",
        "--plot",
        action="store_true",
        help="plot 1d arrays from leaf nodes",
    )
    parser.add_argument(
        "--coordinate",
        help="Provide custom coordinate if required, provide index from coordinates array or ids field",
        default=None,
    )
    parser.add_argument(
        "--save",
        help="Save figure at default location",
        action="store_true",
    )
    parser.add_argument(
        "--directory",
        help="Directory to save the figure",
        default=None,
    )
    return parser


def _list_available_ids(connection):
    """Print a table of the IDSes stored behind *connection* with their time ranges."""
    question_string = "?"
    table = Table(title="List of IDSes")
    table.add_column("IDS", style="magenta")
    table.add_column("SLICES", style="green")
    table.add_column("TIME", style="green")
    for name, time_array in get_available_ids_and_times(connection):
        if len(time_array) == 1 and np.isnan(time_array[0]):
            slices, time_info = question_string, "heterogeneous IDS"
        elif len(time_array) == 1 and time_array[0] == -np.inf:
            slices, time_info = question_string, "time independent IDS"
        else:
            slices, time_info = f"{len(time_array)}", time_array

        # time_info is either a descriptive string or the array of time values.
        if isinstance(time_info, str):
            time_str = time_info
        elif len(time_info) > 0:
            time_str = str(time_info[0]) + ",...," + str(time_info[-1])
        else:
            time_str = "empty"
        table.add_row(name, slices, time_str)
    Console().print(table)


def _load_ids(connection, ids_name, occurrence, time_requested=None, lazy=False):
    """Fetch an IDS, as the closest time slice when *time_requested* is not None."""
    if time_requested is not None:
        return connection.get_slice(
            ids_name,
            occurrence=occurrence,
            time_requested=time_requested,
            interpolation_method=imas.ids_defs.CLOSEST_INTERP,
            autoconvert=False,
            ignore_unknown_dd_version=True,
            lazy=lazy,
        )
    return connection.get(
        ids_name, occurrence=occurrence, autoconvert=False, ignore_unknown_dd_version=True, lazy=lazy
    )


def _prepare_output_path(fname, directory):
    """Return *fname* inside *directory* (created when missing), or unchanged without one."""
    if directory:
        os.makedirs(directory, exist_ok=True)
        fname = os.path.join(directory, fname)
    return fname


def _export_ids(ids, ids_name, args):
    """Export *ids* as ``{ids_name: ...}`` to a MAT, JSON or YAML file (``args.export_type``)."""
    fields = _make_dict_tree(ids, hide_empty_nodes=not args.show_empty, compact=args.compact)
    root_dict = {ids_name: fields}
    if args.export_type == "mat":
        fname = _prepare_output_path(f"{ids_name}.mat", args.directory)
        from scipy.io import savemat

        savemat(fname, root_dict, long_field_names=True)
        logger.info(f"MAT file '{fname}' created successfully!")
    elif args.export_type == "json":
        fname = _prepare_output_path(f"{ids_name}.json", args.directory)
        import json

        with open(fname, "w") as f:
            json.dump(root_dict, f, indent=4)
        logger.info(f"JSON file '{fname}' created successfully!")
    elif args.export_type == "yaml":
        fname = _prepare_output_path(f"{ids_name}.yaml", args.directory)
        import yaml

        with open(fname, "w") as f:
            yaml.dump(root_dict, f)
        logger.info(f"YAML file '{fname}' created successfully!")
    else:
        logger.error("Please provide valid export type")


def _parse_coordinate(ids, coordinate_arg):
    """Resolve ``--coordinate`` into a coordinate index or an IDS node; None on failure."""
    if coordinate_arg.isdigit():
        return int(coordinate_arg)
    # Custom coordinate like "profiles_1d[0]/grid/psi" or "profiles_1d(0)/grid/psi".
    path = coordinate_arg.replace("(", "[").replace(")", "]").replace("/", ".")
    try:
        # NOTE(review): eval of a command-line path; acceptable for a local CLI only.
        return eval("ids." + path)
    except Exception as e:
        logger.error(f"{path} path does not exist, hint: check length of arrays, detailed error {e}")
        return None


def _plot_node(args, node, coordinate, field_name, field_unit, coordinate_name, coordinate_unit):
    """Plot a numeric node against *coordinate*, then show or save the figure."""
    structure_types = (imas.ids_structure.IDSStructure, imas.ids_struct_array.IDSStructArray)
    if isinstance(node, structure_types):
        logger.error("Can not plot data, fragment is structure. inspect data with -i option")
        sys.exit(0)
    if len(node) == 0:
        logger.error("Can not plot data, no values present in the array. inspect data with -i option")
        sys.exit(0)
    if isinstance(node[0], structure_types):
        logger.error("Can not plot data, fragment values are structures. inspect data with -i option")
        sys.exit(0)

    canvas = PlotCanvas(1, 1)
    canvas.update_style(args.rc)
    ax = canvas.add_axes(title="", xlabel="", row=0, col=0)
    view_plot(
        ax,
        node,
        coordinate,
        field_name=field_name,
        coordinate_name=coordinate_name,
        field_unit=field_unit,
        coordinate_unit=coordinate_unit,
    )

    canvas.set_text(text=f"{get_database_path(args, time_value=None)}")
    canvas.fig.subplots_adjust(top=0.916, bottom=0.09, left=0.044, right=0.953, hspace=0.287, wspace=0.2)
    canvas.get_current_fig_manager().set_window_title(os.path.basename(__file__))
    canvas.fig.suptitle(get_title(args, field_name, None))

    if args.save:
        fname = _prepare_output_path(get_file_name(args, os.path.basename(__file__), None), args.directory)
        canvas.save(fname)
    else:
        canvas.show()


def _print_full_ids(args, connection, ids_name, occurrence):
    """Inspect, export or print a whole IDS."""
    ids = _load_ids(connection, ids_name, occurrence, args.time)
    if args.dd_update:
        ids = imas.convert_ids(ids, connection.factory.version)
    if args.inspect:
        imas.util.inspect(ids, hide_empty_nodes=not args.show_empty)
    elif args.export:
        _export_ids(ids, ids_name, args)
    else:
        print_tree(
            ids, hide_empty_nodes=not args.show_empty, compact=args.compact, full_array=args.full, depth=args.depth
        )


def _print_ids_path(args, connection, ids_name, occurrence, ids_path):
    """Inspect, export, plot or print the node addressed by *ids_path* inside an IDS."""
    field_name = f"{ids_name}/{ids_path}"
    # Lazy loading fetches only what is accessed; DD conversion needs a fully loaded IDS.
    lazy = not args.dd_update
    ids = _load_ids(connection, ids_name, occurrence, args.time, lazy=lazy)
    if args.dd_update:
        ids = imas.convert_ids(ids, connection.factory.version)

    coordinate = None
    node_unit = ""
    coordinate_name = ""
    coordinate_unit = ""
    if args.coordinate:
        coordinate = _parse_coordinate(ids, args.coordinate)

    ids_path = ids_path.replace("(", "[").replace(")", "]").replace("/", ".")
    if ":" in ids_path:
        node, path_coordinate, node_unit, coordinate_unit = partial_get(ids, ids_path)
        if path_coordinate is not None:
            # Put the axis matching the coordinate length first.
            if path_coordinate.shape[0] != node.shape[0]:
                node = np.transpose(node)
            if not args.coordinate:
                coordinate = path_coordinate
    else:
        try:
            node = eval("ids." + ids_path)
            # Structures are re-read fully (same slice/options) for display.
            if lazy and isinstance(node, (imas.ids_structure.IDSStructure, imas.ids_struct_array.IDSStructArray)):
                ids = _load_ids(connection, ids_name, occurrence, args.time)
                node = eval("ids." + ids_path)
        except Exception as e:
            logger.error(f"{ids_path} path does not exist, hint: check length of arrays, detailed error {e}")
            node = np.array([]).reshape(
                0,
            )

    if args.inspect:
        imas.util.inspect(node, hide_empty_nodes=not args.show_empty)
    elif args.export:
        _export_ids(ids, ids_name, args)
    elif args.plot:
        _plot_node(args, node, coordinate, field_name, node_unit, coordinate_name, coordinate_unit)
    else:
        print_tree(
            node, hide_empty_nodes=not args.show_empty, compact=args.compact, full_array=args.full, depth=args.depth
        )


def main():
    """Entry point of the ``idsprint`` command line tool."""
    args = _build_parser().parse_args()
    uri_dict = parse_uri(args.uri)
    occurrence = uri_dict["occurrence"] or 0
    ids_name = uri_dict["ids_name"]
    ids_path = uri_dict["ids_path"]
    # The connection is opened from the URI without its #fragment.
    args.uri = uri_dict["uri_part"]
    connection = DBMaster.get_connection(args)
    if connection is None:
        logger.error("Please provide valid URI")
        sys.exit(1)

    try:
        if not ids_name:
            logger.info("Please provide ids name at the end of uri with # character")
            logger.info("Following are available IDSes")
            _list_available_ids(connection)
            sys.exit(1)
        if ids_path is None:
            _print_full_ids(args, connection, ids_name, occurrence)
        else:
            _print_ids_path(args, connection, ids_name, occurrence, ids_path)
    finally:
        # Also runs on sys.exit() and on errors, so the backend is always released.
        connection.close()


if __name__ == "__main__":
    main()
