#!python
# pylint: disable=no-member,ungrouped-imports,invalid-name,redefined-outer-name

import argparse
import inspect
import os
import sys
from datetime import datetime
from pathlib import Path

import jax
import numpy as np
import optax
import yaml
from matplotlib.pyplot import close

from nabu import Likelihood, PosteriorTransform
from nabu.flow import (
    RationalQuadraticSpline,
    available_activations,
    coupling_flow,
    masked_autoregressive_flow,
)
from nabu.plotting import summary_plot
from nabu.tensorboard import SummaryWriter
from nabu.transform_base import (
    standardise_between_negone_and_one,
    standardise_between_zero_and_one,
    standardise_dalitz,
    standardise_mean_std,
    standardise_median_quantile,
)

# Dispatch table from the standardisation method name (the ``method`` entry of
# the standardisation YAML configuration and the accepted values of
# ``--write-configuration``) to the ``nabu.transform_base`` function that
# implements it.
standardisation_methods = {
    "between_negone_and_one": standardise_between_negone_and_one,
    "between_zero_and_one": standardise_between_zero_and_one,
    "dalitz": standardise_dalitz,
    "mean_std": standardise_mean_std,
    "median_quantile": standardise_median_quantile,
}


def _load_datasets(args: argparse.Namespace):
    """Load the training and test samples, standardising them if requested.

    Without ``args.STDCONFIG`` the archive at ``args.DATAPATH`` must contain
    ``X_train`` and ``X_test``. With it, the archive must contain a 2D
    ``dataset`` array which is standardised with the configured method,
    shuffled, split according to ``train_frac`` and written to
    ``standardised_dataset.npz`` next to the configuration file.

    Returns:
        Tuple ``(p_trans, X_train, X_test)`` where ``p_trans`` is the
        transform returned by the standardisation method, or a default
        ``PosteriorTransform()`` when no standardisation is applied.

    Raises:
        KeyError: if the expected arrays are missing from the archive.
    """
    p_trans = PosteriorTransform()

    # The archive object returned by ``np.load`` keeps the file open, so the
    # required arrays are extracted once and the file is closed right away.
    with np.load(args.DATAPATH) as archive:
        if args.STDCONFIG is None:
            return p_trans, archive["X_train"], archive["X_test"]

        with open(args.STDCONFIG, "r", encoding="utf-8") as f:
            std_conf = yaml.safe_load(f)
        try:
            new_data = archive["dataset"]
        except KeyError as exc:
            raise KeyError(
                "Compressed NumPy file is expected to have the dataset under `dataset` keyword."
            ) from exc

    assert (
        len(new_data.shape) == 2
    ), f"Invalid dimensionality, expected 2 got {len(new_data.shape)}"

    # ``keywordargs`` may be absent (or empty) for methods that take no extra
    # arguments; treat that as "no keyword arguments".
    p_trans, new_data = standardisation_methods[std_conf["method"]](
        new_data, **(std_conf.get("keywordargs") or {})
    )
    np.random.shuffle(new_data)

    # Split on the number of samples of the standardised array.
    n_train = int(new_data.shape[0] * std_conf["train_frac"])
    # pylint: disable=unbalanced-tuple-unpacking
    X_train, X_test = np.split(new_data, [n_train])
    np.savez_compressed(
        os.path.join(str(Path(args.STDCONFIG).parent), "standardised_dataset.npz"),
        X_train=X_train,
        X_test=X_test,
    )
    return p_trans, X_train, X_test


def main(args: argparse.Namespace) -> None:
    """Fit a normalising flow to the data until it passes the goodness-of-fit test.

    A new flow with fresh random seeds is built and trained in every
    iteration. The loop ends when the smaller of the chi2 and KS p-values
    reaches ``args.PVALTHRES``, when ``args.TEST`` is set (single iteration
    with fixed seeds), or when a sibling directory of ``args.OUTPATH`` whose
    name contains ``-COMPLETE`` is found. On success the model, the
    deviations of the test set and the seeds/p-values are written to
    ``args.OUTPATH``, which is then renamed with a ``-COMPLETE`` suffix.

    Args:
        args: parsed command line arguments (see the argument parser in the
            ``__main__`` section for the available fields).
    """
    p_trans, X_train, X_test = _load_datasets(args)

    dim = X_train.shape[1]
    if 0.0 < args.DATAFRAC < 1.0:
        # NOTE(review): ``np.random.choice`` samples with replacement by
        # default, so the reduced training set can contain duplicated rows.
        # Confirm whether ``replace=False`` was intended.
        X_train = X_train[
            np.random.choice(
                np.arange(X_train.shape[0]), size=int(X_train.shape[0] * args.DATAFRAC)
            )
        ]
    print(
        f"Training for {X_train.shape[0]}, {dim}D samples. "
        "Pour yourself a cup of coffee; this will take some time..."
    )

    if args.TRANS == "affine":
        transformer = None
    else:
        transformer = RationalQuadraticSpline(
            knots=args.KNOTS,
            # a single value is passed as a scalar, two values as (lower, upper)
            interval=args.INTERVAL[0]
            if len(args.INTERVAL) == 1
            else tuple(args.INTERVAL),
        )

    while True:
        # check if there are any completed scans
        # important in case of a parallel scans running at the same time
        if any("-COMPLETE" in fl for fl in os.listdir(os.path.split(args.OUTPATH)[0])):
            print("Found a completed scan.")
            break

        model_random_seed = 42 if args.TEST else np.random.randint(0, high=999999999999)
        fit_random_seed = 42 if args.TEST else np.random.randint(0, high=999999999999)

        if args.FLOW == "maf":
            likelihood: Likelihood = masked_autoregressive_flow(
                dim=dim,
                transformer=transformer,
                flow_layers=args.FLOWLAYERS,
                nn_width=args.NNWIDTH,
                nn_depth=args.NNDEPTH,
                activation=args.ACTIVATION,
                permutation=args.PERMUTATION,
                random_seed=model_random_seed,
            )
        else:
            likelihood = coupling_flow(
                dim=dim,
                transformer=transformer,
                flow_layers=args.FLOWLAYERS,
                nn_width=[args.NNWIDTH] * args.NNDEPTH,
                activation=args.ACTIVATION,
                permutation=args.PERMUTATION,
                random_seed=model_random_seed,
            )

        scheduler = optax.exponential_decay(
            init_value=args.LR,
            transition_steps=args.TRANSSTEP,
            decay_rate=args.DECAYRATE,
            staircase=True,
            end_value=args.MINLR,
        )

        # tag the results with current time and date
        log_date = datetime.now().strftime("%b%d-%H-%M-%S")
        log_name = os.path.join(args.OUTPATH, f"log/fit_{log_date}")
        _ = likelihood.fit_to_data(
            dataset=X_train,
            L1_regularisation_coef=args.L1,
            L2_regularisation_coef=args.L2,
            learning_rate=args.LR,
            optimizer="adam",
            lr_scheduler=scheduler,
            max_epochs=args.EPOCHS,
            max_patience=args.MAXPATIENCE,
            batch_size=args.BATCH,
            random_seed=fit_random_seed,
            log=log_name,
        )

        hist = likelihood.goodness_of_fit(test_dataset=X_test, prob_per_bin=args.PROB)
        chi2_pval = hist.residuals_pvalue
        kst_pval = hist.kstest_pval
        # the fit is accepted only if both tests pass the threshold
        pval = min(chi2_pval, kst_pval)
        goodness_of_fit_test = pval >= args.PVALTHRES

        # Fill the tensorboard report with the results
        report = f"p(χ²) = {chi2_pval*100:.3g}%, p(KS) = {kst_pval*100:.3g}%"
        logger = SummaryWriter(log_name)
        logger.text(
            "Report",
            report + "<p>Residuals: " + ", ".join(f"{x:.3g}" for x in hist.pull) + "</p>",
        )
        print("   * " + report)
        print(f"   * Log recorded in: {log_name}")

        fig, _ = summary_plot(
            likelihood=likelihood,
            test_data=X_test,
            prob_per_bin=args.PROB,
        )
        logger.figure("Summary-plot", fig)
        close(fig)

        if goodness_of_fit_test or args.TEST:
            likelihood.transform = p_trans
            likelihood.save(os.path.join(args.OUTPATH, "model.nabu"))
            np.savez_compressed(
                os.path.join(args.OUTPATH, "deviations.npz"),
                deviations=likelihood.compute_inverse(X_test),
            )
            # Save random seeds
            with open(
                os.path.join(args.OUTPATH, "config.yaml"), "a", encoding="utf-8"
            ) as f:
                yaml.safe_dump(
                    {
                        "model_random_seed": int(model_random_seed),
                        "fit_random_seed": int(fit_random_seed),
                        "chi2_pval": float(chi2_pval),
                        "kst_pval": float(kst_pval),
                    },
                    f,
                )
            # mark the output directory so that parallel scans can stop
            os.rename(args.OUTPATH, args.OUTPATH + "-COMPLETE")
            args.OUTPATH += "-COMPLETE"
            break
    print(" * Files are saved in", args.OUTPATH)


def standardisation_defaults(method):
    """Return the default keyword arguments of a standardisation function.

    The first positional argument (the dataset) is always skipped. Positional
    defaults are matched to the *trailing* arguments, which is how Python
    assigns them, and keyword-only defaults are included as well.

    Args:
        method: callable whose signature is inspected.

    Returns:
        ``dict`` mapping argument names to their default values; empty when
        the callable has no arguments with defaults.
    """
    spec = inspect.getfullargspec(method)
    optional = spec.args[1:]
    # ``defaults`` is None (not an empty tuple) when nothing has a default
    defaults = spec.defaults or ()
    n_defaults = min(len(optional), len(defaults))
    keywordargs = dict(
        zip(
            optional[len(optional) - n_defaults :],
            defaults[len(defaults) - n_defaults :],
        )
    )
    keywordargs.update(spec.kwonlydefaults or {})
    return keywordargs


if __name__ == "__main__":

    # Computed once so that the help text and the actual default always agree.
    default_out_path = os.path.join(os.getcwd(), "results")
    default_out_name = datetime.now().strftime("%b%d-%H-%M-%S")

    parser = argparse.ArgumentParser(description="Search model space for proper nflow")

    parameters = parser.add_argument_group("Hyperparameters")
    parameters.add_argument(
        "--flow-type",
        "-ft",
        type=str,
        default="maf",
        choices=["maf", "coupling"],
        help="Type of flow, defaults to maf.",
        dest="FLOW",
    )
    parameters.add_argument(
        "--nn-width",
        "-w",
        type=int,
        default=32,
        help="Number of nodes per layer",
        dest="NNWIDTH",
    )
    parameters.add_argument(
        "--nn-depth",
        "-d",
        type=int,
        default=3,
        help="Number of layers",
        dest="NNDEPTH",
    )
    parameters.add_argument(
        "--flow-layers",
        "-l",
        type=int,
        default=2,
        help="Number of flow layers",
        dest="FLOWLAYERS",
    )
    parameters.add_argument(
        "--activation",
        "-a",
        type=str,
        default="relu",
        choices=available_activations(),
        help="Activation function for MLP.",
        dest="ACTIVATION",
    )
    parameters.add_argument(
        "--permutation",
        "-perm",
        type=str,
        default="reversed",
        choices=["reversed", "random"],
        help="Permutation for the flow, defaults to reversed.",
        dest="PERMUTATION",
    )
    parameters.add_argument(
        "-lr", type=float, default=1e-3, help="learning rate", dest="LR"
    )
    parameters.add_argument(
        "-mlr", type=float, default=1e-4, help="minlearning rate", dest="MINLR"
    )
    parameters.add_argument(
        "--l1-regularisation",
        "-l1",
        type=float,
        default=0.0,
        help="L1 Regularisation coefficient, defaults to 0.",
        dest="L1",
    )
    parameters.add_argument(
        "--l2-regularisation",
        "-l2",
        type=float,
        default=0.0,
        help="L2 Regularisation coefficient, defaults to 0.",
        dest="L2",
    )
    parameters.add_argument(
        "--max-patience",
        "-mp",
        type=int,
        default=50,
        help="Number of epochs for training to monitor validation "
        "loss before terminating. Defaults to 50.",
        dest="MAXPATIENCE",
    )
    parameters.add_argument(
        "--decay-rate",
        "-dr",
        type=float,
        default=0.5,
        help="Decay rate for LR scheduler, defaults to 0.5",
        dest="DECAYRATE",
    )
    parameters.add_argument(
        "--transition-steps",
        "-ts",
        type=int,
        default=25,
        help="Number of epochs for LR scheduler to wait before decaying, "
        "defaults to 25",
        dest="TRANSSTEP",
    )
    parameters.add_argument(
        "--epochs",
        "-e",
        type=int,
        default=600,
        help="Number of epochs, defaults to 600",
        dest="EPOCHS",
    )
    parameters.add_argument(
        "--batch-size",
        "-bs",
        type=int,
        default=100,
        help="Batch size, defaults to 100",
        dest="BATCH",
    )
    parameters.add_argument(
        "--pval-threshold",
        "-pt",
        type=float,
        default=0.03,
        help="Threshold for 1-CDF of the residuals, defaults to >= 0.03",
        dest="PVALTHRES",
    )
    parameters.add_argument(
        "--prob-per-bin",
        "-p",
        type=float,
        default=0.1,
        help="Yield probability per bin for goodness of fit test, defaults to 0.1",
        dest="PROB",
    )

    ansatz = parser.add_argument_group("Ansatz")
    ansatz.add_argument(
        "--transformer",
        "-t",
        type=str,
        default="affine",
        choices=["affine", "rqs"],
        help="Choices of transformers to be used in MAF, defaults to affine.",
        dest="TRANS",
    )
    ansatz.add_argument(
        "--knots",
        "-k",
        type=int,
        default=6,
        help="Knots for RQS only, defaults to 6",
        dest="KNOTS",
    )
    ansatz.add_argument(
        "--interval",
        "-int",
        type=float,
        nargs="+",
        default=[4.0],
        help="Interval for RQS only, defaults to 4. "
        "If two value is given, will be considered as the lower and upper bounds.",
        dest="INTERVAL",
    )

    dataset = parser.add_argument_group("Dataset")
    dataset.add_argument(
        "--data-path",
        "-dp",
        type=str,
        # Each literal ends with a space so the concatenated sentences do not
        # run into each other in the rendered help text.
        help="NumPy file includes standardised dataset. "
        "File should include `X_train` and `X_test` keyword arguments. "
        "Use the following to save your dataset `np.savez_compressed('my_dataset.npz', X_train=X_train, X_test=X_test)`. "
        "Notice that `X_train` corresponds to the samples that will be used for training and "
        "`X_test` corresponds to the data that will be used for testing. If --standardisation-config is used "
        "data set is expected to be in `dataset` keyword within compressed numpy file.",
        dest="DATAPATH",
        default="__no-file__.npz",
    )
    dataset.add_argument(
        "--train-frac",
        "-tfrac",
        type=float,
        default=1.0,
        help="Fraction of the data to be used as the training set.",
        dest="DATAFRAC",
    )
    dataset.add_argument(
        "--write-configuration",
        type=str,
        default=None,
        choices=list(standardisation_methods.keys()),
        help="Write configuration file for the given standardisation method and exit.",
        dest="WRITESTDCONF",
    )
    dataset.add_argument(
        "--standardisation-config",
        "-sconf",
        type=str,
        default=None,
        help="Use standardisation configuration to standardise the data before training.",
        dest="STDCONFIG",
    )

    paths = parser.add_argument_group("Paths")
    paths.add_argument(
        "--out-path",
        "-op",
        type=str,
        help="Output path, defaults to " + default_out_path,
        dest="OUTPATH",
        default=default_out_path,
    )
    paths.add_argument(
        "--out-name",
        "-on",
        type=str,
        help="Output name, default to " + default_out_name,
        dest="OUTNAME",
        default=default_out_name,
    )

    execution = parser.add_argument_group("Execution")
    execution.add_argument(
        "-gpu",
        action="store_true",
        default=False,
        help="Run on GPU, defaults to False",
        dest="GPU",
    )
    execution.add_argument(
        "-test",
        action="store_true",
        default=False,
        help="Run on the Test routine. This routine fixes the random "
        "seeds and executes the algorithm once.",
        dest="TEST",
    )

    args = parser.parse_args()

    if args.WRITESTDCONF is not None:
        # Write a template configuration for the requested method and stop.
        # ``keywordargs`` is always present (possibly empty) so that the file
        # has the same layout for every method.
        configuration = {
            "method": args.WRITESTDCONF,
            "train_frac": 0.8,
            "keywordargs": standardisation_defaults(
                standardisation_methods[args.WRITESTDCONF]
            ),
        }
        with open(
            os.path.join(os.getcwd(), "standardisation.yaml"), "w", encoding="utf-8"
        ) as f:
            yaml.safe_dump(configuration, f)
        sys.exit(0)

    # Explicit check instead of ``assert`` so it is not stripped under ``-O``.
    if len(args.INTERVAL) > 2:
        parser.error("RQS Interval should be a single value or a pair.")

    # check data before creating any output directory, so that an invalid
    # input does not leave empty result folders behind
    data_path = Path(args.DATAPATH)
    if not data_path.is_file():
        raise ValueError(f"Can not find dataset: {data_path}")
    if data_path.suffix != ".npz":
        raise ValueError(
            "Input file needs to be a compressed NumPy file with .npz extension,"
            f" {data_path.suffix} is given."
        )

    if args.TEST:
        args.OUTNAME = "TEST-RESULT"

    # ``makedirs`` creates missing parents and tolerates the directory having
    # been created in the meantime by a scan running in parallel.
    args.OUTPATH = os.path.join(args.OUTPATH, args.OUTNAME)
    os.makedirs(args.OUTPATH, exist_ok=True)

    print("<><><> Arguments <><><>")
    for key, item in vars(args).items():
        print(f"   * {key} : {item}")
    print("<><><><><><><><><><><>")

    # Record the full configuration (plus the host name) next to the results.
    arg_dict = {**vars(args), "HOSTNAME": os.environ.get("HOSTNAME", None)}
    with open(os.path.join(args.OUTPATH, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.safe_dump(arg_dict, f)

    jax.config.update("jax_platform_name", "gpu" if args.GPU else "cpu")
    main(args)
