#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Small utility script to profile the performance of IMAS Access Layer (AL) operations on any dataset."""

import argparse
import getpass
import os
import re
import statistics as stat
import sys

from rich_argparse import RichHelpFormatter

from idstools import idsperf
from idstools.database import DBMaster
from idstools.utils.clihelper import uri_parser
from idstools.utils.idshelper import get_available_ids_and_times
from idstools.utils.idslogger import setup_logger

# Module-level logger used by this script to report errors (e.g. IDSs that cannot be read).
logger = setup_logger("module")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="This script gives simple timing and performance information for different types of"
        "operations on IDS data with the IMAS Python Access Layer (get/get_slice/put depending on selected options)",
        parents=[uri_parser],
        formatter_class=RichHelpFormatter,
    )
    timegroup = parser.add_mutually_exclusive_group()
    timegroup.add_argument(
        "-t",
        "--slice-time",
        type=float,
        nargs="+",
        help="Use get_slice with selected time(s)",
    )
    timegroup.add_argument(
        "-a",
        "--all_slices",
        action="store_true",
        help="Use get_slice with all available times",
    )
    parser.add_argument(
        "-m",
        "--memory-backend",
        action="store_true",
        help="Use MEMORY_BACKEND for this test (involve reading from file and loading in memory first)",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="Repeat timing n times (default: %(default)s)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Verbose mode, prints additional information",
    )
    parser.add_argument("--show-stats", action="store_true", help="Print addition stats for timings")
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Also do full profile of the selected operation, via cProfile",
    )
    parser.add_argument("-o", "--output-run", type=int, help="Output run number for checking perf of put")
    parser.add_argument(
        "ids",
        nargs="*",
        type=str,
        help=(
            "IDS name(s) (leave empty to select all IDSs with default occurrence, or"
            'append "/n" to copy a specific occurrence "n")'
        ),
    )
    parser.add_argument(
        "-uri-out",
        "--uri-out",
        type=str,
        help="uri out \t\t(default=%(default)s)",
    )
    args = parser.parse_args()

    # hack to get underlying data dictionary version used by idses
    src_connection = DBMaster.get_connection(args)
    if src_connection is None:
        print("Error opening source pulse! Please check existence.")
        sys.exit()
    if args.ids == []:
        args.ids = [ids for ids, _ in get_available_ids_and_times(src_connection)]
    src_dd_version = src_connection.dd_version
    for idsname in args.ids:
        nameocc = idsname.split("/")
        idsname = nameocc[0]
        occ = 0
        if len(nameocc) == 2:
            occ = int(nameocc[1])
        try:
            _dummy = src_connection.get(idsname, occurrence=occ, lazy=True, autoconvert=False)
            src_dd_version = _dummy.ids_properties.version_put.data_dictionary.value
            break
        except Exception as e:
            logger.error(f"Exception occurred, detailed error {e}")
    src_connection.close()

    # For memory backend need to read the full IDS and load it into memory
    if args.memory_backend:
        print("First import data into memory...", file=sys.stderr)
        # sys.exit(1)
        dbref = DBMaster.get_connection(args)
        if dbref is None:
            print("----> Aborted.")
            exit(1)

        args.backend = "MEMORY"
        if "uri" in args.__dict__ and args.uri is not None:
            start_index = args.uri.index(":") + 1
            end_index = args.uri.index("?")

            # Replace the substring between ':' and '?' with 'memory'
            args.uri = args.uri[:start_index] + "memory" + args.uri[end_index:]
        dbm = DBMaster.create_connection(args, src_dd_version)
        if dbm is None:
            print("----> Aborted.")
            exit(1)

        for idsname in args.ids:
            nameocc = idsname.split("/")
            idsname = nameocc[0]
            occ = 0
            if len(nameocc) == 2:
                occ = int(nameocc[1])
            dbm.put(dbref.get(idsname, occ, autoconvert=False))
        args.mode = "r"
    db = DBMaster.get_connection(args)
    if db is None:
        print("----> Aborted.")
        exit(1)

    idsobj = None
    dbout = None

    # Set command
    dbout_args = argparse.Namespace()
    if args.uri_out or args.output_run:
        if args.uri_out:
            dbout_args.uri = args.uri_out
            if args.verbose:
                print("Using given uri_out for checking put performance")
        elif args.output_run:
            import re

            dbout_args.uri = args.uri
            dbout_args.uri = re.sub(r"user=[^;]+", f"user={os.environ['USER']}", dbout_args.uri)
            dbout_args.uri = re.sub(r"run=[^;]+", f"run={args.output_run}", dbout_args.uri)

        dbout = DBMaster.create_connection(dbout_args, src_dd_version)
        if dbout is None:
            print("----> Aborted.")
            exit(1)
        if args.verbose:
            print("Create pulse for put operation")

    for idsname in args.ids:
        nameocc = idsname.split("/")
        idsname = nameocc[0]
        occ = 0
        if len(nameocc) == 2:
            occ = int(nameocc[1])

        if args.all_slices:
            # TODO: might not be relevant if homogeneous_time != 1
            ids = db.get(idsname, occurrence=occ, autoconvert=False)
            times = (ids.time).tolist()
        else:
            times = args.slice_time

        timings = idsperf.get_timings(
            db,
            idsname,
            occ=occ,
            times=times,
            repeat=args.repeat,
            verbose=args.verbose,
            profile=args.profile,
            dd_update=False,
        )

        if args.show_stats and args.repeat > 1:
            print(f"All timings  = {timings}")
            print(f"Mean         = {stat.mean(timings)}")
            print(f"Standard dev = {stat.stdev(timings)}")
            print(f"Variance     = {stat.variance(timings)}")
        print(f"{idsname}/{occ} best time = {min(timings)} s")
