#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import logging
import sys

try:
    import imaspy as imas
except ImportError:
    import imas
from rich_argparse import RichHelpFormatter

from idstools.database import DBMaster
from idstools.utils.idshelper import (
    get_available_ids_and_occurrences,
    resample_indices,
    resample_times,
)
from idstools.utils.idslogger import setup_logger

# Map each --slicingmethod choice to its imas interpolation constant:
# "CLOSEST" -> imas.ids_defs.CLOSEST_INTERP, "PREVIOUS" -> PREVIOUS_INTERP,
# "LINEAR" -> LINEAR_INTERP (insertion order CLOSEST, PREVIOUS, LINEAR).
slicing_methods = {
    method: getattr(imas.ids_defs, f"{method}_INTERP")
    for method in ("CLOSEST", "PREVIOUS", "LINEAR")
}
# Script-wide logger built by idstools' setup_logger; stdout_level=INFO is
# presumably the stdout handler threshold (see idstools.utils.idslogger).
logger = setup_logger("module", stdout_level=logging.INFO)
def _build_parser():
    """Return the argument parser for the resampling command line."""
    parser = argparse.ArgumentParser(
        description="Resample IDSs from a data-entry and save them into another data-entry "
        "[previously known as ids_resampling]",
        formatter_class=RichHelpFormatter,
    )
    parser.add_argument(
        "-s",
        "--src",
        type=str,
        required=True,
        help="source uri",
    )
    parser.add_argument(
        "-d",
        "--dest",
        type=str,
        required=True,
        help="destination uri",
    )
    parser.add_argument(
        "--dd-update",
        action="store_true",
        help=(
            "Convert IDS to the default version of the data dictionary if enabled, "
            "otherwise, use the original IDS stored on disk."
        ),
    )
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--index-range",
        type=str,
        help="Specified range of slices index as 'start,stop,step'. If omitted, start=0, stop=len(timebase), step=1,"
        " e.g. '0,,10' to keep 1 every 10 slices. Works only for IDS with homogeneous timebase. (Default)",
    )
    group.add_argument(
        "--time-range",
        type=str,
        help="Specified range of times as 'start,stop,step'. If omitted, start=time[0], stop=time[-1], while "
        "omitting step will keep all slices between start and stop, e.g. '10.,50.,' to keep all times between 10. "
        "and 50. seconds. Works only for IDS with homogeneous timebase unless all three values are specified.",
    )
    parser.add_argument(
        "-m",
        "--slicingmethod",
        type=str,
        default="CLOSEST",
        # Derived from the mapping so the CLI choices cannot drift from it.
        choices=list(slicing_methods),
        help="Slicing method \t(default=%(default)s)",
    )
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        help="Force the creation of destination data-entry (existing data will be lost)",
    )
    parser.add_argument(
        "ids",
        nargs="*",
        type=str,
        help="IDSs to resample (leave empty to resample all)",
    )
    return parser


def parse_range(text, option, convert):
    """Parse a ``'start,stop,step'`` option value.

    Empty or missing fields become ``None`` ("use the per-IDS default"), so
    ``'0,10'`` is accepted as ``'0,10,'`` and ``None``/``''`` as ``',,'``.

    Args:
        text: raw option value, or ``None``.
        option: option name used in error messages (e.g. ``"--time-range"``).
        convert: callable applied to each non-empty field (``int`` or ``float``).

    Returns:
        A 3-tuple ``(start, stop, step)`` of converted values or ``None``.

    Raises:
        ValueError: on more than three fields, a field ``convert`` rejects,
            a zero step, or explicit and equal start/stop.
    """
    fields = [field.strip() for field in (text or "").split(",")]
    if len(fields) > 3:
        raise ValueError(f"{option} expects 'start,stop,step', got {text!r}")
    fields += [""] * (3 - len(fields))
    try:
        start, stop, step = (convert(field) if field else None for field in fields)
    except ValueError as err:
        raise ValueError(f"invalid {option} value {text!r}: {err}") from err
    if step == 0:
        raise ValueError(f"{option}: step must be non-zero")
    if start is not None and start == stop:
        raise ValueError(f"Please provide range. start={start} and stop={stop} are the same")
    return start, stop, step


def _source_dd_version(src, available_ids, dd_update):
    """Return the DD version to create the destination entry with.

    With ``dd_update`` this is ``src.dd_version``. Otherwise it is
    ``ids_properties.version_put.data_dictionary`` of the first IDS that can
    be lazily read from *src*, falling back to ``src.dd_version`` when none
    can be read (each failure is logged).
    """
    version = src.dd_version
    if dd_update:
        return version
    for ids_name, _occurrence in available_ids:
        try:
            probe = src.get(ids_name, lazy=True, autoconvert=False)
            return probe.ids_properties.version_put.data_dictionary.value
        except Exception as e:
            logger.error(f"Exception occurred detailed description : {e}")
    return version


def _resample_ids(src, tmp, ids_name, idsobj, bounds, use_time_range, interpolation):
    """Resample one IDS from *src* into the in-memory entry *tmp*.

    Args:
        src: source data-entry.
        tmp: temporary data-entry receiving the resampled IDS.
        ids_name: name of the IDS to resample.
        idsobj: the IDS as read from *src*; only its homogeneous_time flag and
            ``time`` vector are inspected here.
        bounds: ``(start, stop, step)`` from :func:`parse_range`; ``None``
            items fall back to per-IDS defaults.
        use_time_range: True to resample on times (``resample_times``),
            False to resample on slice indices (``resample_indices``).
        interpolation: imas interpolation constant passed to ``resample_times``.

    Returns:
        True when *tmp* now holds a resampled copy that should be written to
        the destination; False when the IDS was skipped or resampling failed
        (a message is printed in both cases, except for homogeneous_time
        values other than 0 and 1, which are skipped silently).
    """
    start, stop, step = bounds
    homogeneous_time = idsobj.ids_properties.homogeneous_time
    if homogeneous_time == 1:
        n_slices = len(idsobj.time)
        if n_slices == 0:
            # Without slices there is no default range to resample.
            print(f"Skipping IDS {ids_name} because it has no time slices", file=sys.stderr)
            return False
        if use_time_range:
            what = "times"
            start = float(idsobj.time[0]) if start is None else start
            stop = float(idsobj.time[-1]) if stop is None else stop
            # NOTE(review): the --time-range help says an omitted step keeps all
            # slices, yet a fixed 1.0 is passed; confirm against resample_times.
            step = 1.0 if step is None else step
        else:
            what = "indices"
            start = 0 if start is None else start
            stop = n_slices if stop is None else stop
            step = 1 if step is None else step
        print(f"resampling {what} :{ids_name} with {start}, {stop}, {step}")
        if start == stop:
            # Only reachable when a default made start and stop coincide
            # (explicit equal values are rejected by parse_range); skip this
            # IDS instead of aborting the whole run.
            print(f"Please provide range. start={start} and stop={stop} are the same", file=sys.stderr)
            return False
    elif homogeneous_time == 0:
        # Heterogeneous timebases need all three time values to be explicit.
        if not use_time_range or None in bounds:
            print(
                f"Skipping IDS {ids_name} because its homogeneous_time={homogeneous_time}",
                file=sys.stderr,
            )
            return False
        what = "times"
        print(f"resampling times :{ids_name}")
    else:
        return False

    try:
        if use_time_range:
            resample_times(
                src,
                tmp,
                ids_name,
                occurrence=0,
                start=start,
                stop=stop,
                step=step,
                interpolation_method=interpolation,
            )
        else:
            resample_indices(src, tmp, ids_name, start=start, stop=stop, step=step)
    except Exception as e:
        print(f"Error resampling {what} for {ids_name} : {e}")
        return False
    return True


def main():
    """Resample the requested IDSs of --src and store them into --dest.

    Exits with status 1 on runtime errors and 2 (via ``parser.error``) on an
    invalid range option. All opened data-entries are closed on every path.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.src == args.dest:
        print("Can not use the same uri as source and destination!", file=sys.stderr)
        sys.exit(1)
    if not args.dest:
        logger.error("Error creating destination pulse! Please check parameters and permissions.")
        sys.exit(1)

    # Validate the range option once, before any data-entry is opened.
    use_time_range = args.time_range is not None
    try:
        if use_time_range:
            bounds = parse_range(args.time_range, "--time-range", float)
        else:
            bounds = parse_range(args.index_range, "--index-range", int)
    except ValueError as err:
        parser.error(str(err))

    src = DBMaster.get_connection(argparse.Namespace(uri=args.src))
    if src is None:
        print("Error opening source pulse! Please check existence.", file=sys.stderr)
        sys.exit(1)

    tmp = dest = None
    try:
        available_ids = get_available_ids_and_occurrences(src)
        ids_list = [(name, 0) for name in args.ids] if args.ids else available_ids
        src_dd_version = _source_dd_version(src, available_ids, args.dd_update)

        tmp = imas.DBEntry("imas:memory?path=tmp/1/1", "w")
        dest = imas.DBEntry(args.dest, "w" if args.force else "a", dd_version=src_dd_version)

        interpolation = slicing_methods[args.slicingmethod]
        processed = set()
        for ids_name, _occurrence in ids_list:
            # src.get/dest.put are called without an occurrence argument and
            # resample_times with occurrence=0, so only one occurrence is ever
            # handled; process each IDS name once.
            if ids_name in processed:
                continue
            processed.add(ids_name)

            idsobj = None
            try:
                idsobj = src.get(ids_name, autoconvert=False)
            except Exception as e:
                logger.error(f"Exception occurred, detailed error {e}")
            if not idsobj:
                continue
            if _resample_ids(src, tmp, ids_name, idsobj, bounds, use_time_range, interpolation):
                dest.put(tmp.get(ids_name, autoconvert=False))
    finally:
        for entry in (dest, tmp, src):
            if entry is not None:
                entry.close()


if __name__ == "__main__":
    main()
