#!/usr/bin/env python
import argparse
import logging
import sys

import numpy as np
from ase import Atoms


# Log line layout: timestamp, first letter of the level name, then the message.
LOG_FMT = "%(asctime)s %(levelname).1s - %(message)s"
# Configure the root logger at import time; INFO and above are emitted.
logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%Y/%m/%d %H:%M:%S")
# No name is passed, so this is the root logger; the functions below log through it.
logger = logging.getLogger()


def preprocess_args(args):
    """Normalize parsed command-line arguments in place.

    * ``args.output_suffix`` gets a leading ``-`` when it does not have one.
    * ``args.checkpoint_path``, when set, loses a trailing ``.index``.

    Args:
        args: Namespace-like object with an ``output_suffix`` string and a
            ``checkpoint_path`` attribute (string or None). It is modified
            in place; nothing is returned.
    """
    if not args.output_suffix.startswith("-"):
        args.output_suffix = "-" + args.output_suffix

    index_ext = ".index"
    if args.checkpoint_path and args.checkpoint_path.endswith(index_ext):
        # Cut off only the trailing extension. str.replace() would also delete
        # every other ".index" occurring earlier in the path (for example in a
        # directory name) and silently point at a different location.
        args.checkpoint_path = args.checkpoint_path[: -len(index_ext)]


def update_model(args):
    """Convert a model.yaml and its checkpoint to the dict-of-instructions format.

    The instructions are loaded from ``args.potential``. If they are already a
    dict, the process exits with status 0. Otherwise they are re-keyed by
    instruction name, the checkpoint is loaded into the old model, the model's
    instructions are replaced by the dict, and both are written out:

    * checkpoint -> ``<checkpoint_path><output_suffix>``
    * model      -> the model path with ``output_suffix`` inserted before its
      file extension

    Finally the converted model and checkpoint are loaded again as a check.

    Args:
        args: Parsed CLI namespace; ``potential``, ``checkpoint_path`` and
            ``output_suffix`` are read (after ``preprocess_args``).

    Raises:
        ValueError: If no checkpoint path was given.
        AssertionError: If two instructions share the same name.
        SystemExit: With status 0 when the model is already dict-like.
    """
    preprocess_args(args)

    import os

    from tensorpotential import TensorPotential
    from tensorpotential.instructions.base import (
        load_instructions,
        save_instructions_dict,
    )

    model_path, checkpoint_path, output_suffix = (
        args.potential,
        args.checkpoint_path,
        args.output_suffix,
    )

    # preprocess_args() has already stripped a trailing ".index"; what is left
    # to validate is that a checkpoint was given at all (the CLI option is
    # optional, and None would otherwise fail with an obscure AttributeError).
    if checkpoint_path is None:
        raise ValueError(
            "update_model requires a checkpoint path (-c/--checkpoint-path)"
        )

    logger.info(
        f"Converting model {model_path} with checkpoint at {checkpoint_path} to dict-like model"
    )

    logger.info(f"Loading model from {model_path}")
    instructions = load_instructions(model_path)
    if isinstance(instructions, dict):
        logger.info(f"Model in {model_path} is already in new format (dict-like)")
        sys.exit(0)

    # convert to dict-of-instructions
    instructions_dict = {ins.name: ins for ins in instructions}
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept so the failure type does not change.
    if len(instructions_dict) != len(instructions):
        raise AssertionError(
            f"Instruction names in {model_path} are not unique, cannot build dict-like model"
        )

    tp = TensorPotential(potential=instructions)
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    tp.load_checkpoint(
        checkpoint_name=checkpoint_path,
        expect_partial=True,
        verbose=True,
        raise_errors=True,
    )

    tp.model.instructions = instructions_dict

    new_checkpoint_path = checkpoint_path + output_suffix
    logger.info(f"Saving converted checkpoint to {new_checkpoint_path}")
    tp.save_checkpoint(checkpoint_name=new_checkpoint_path, verbose=True)

    # Insert the suffix in front of the real file extension only. A plain
    # str.replace(".yaml", ...) would also rewrite directory names containing
    # ".yaml"/".yml" and, for a file without such an extension, would leave the
    # name unchanged so that the source model got overwritten.
    model_root, model_ext = os.path.splitext(model_path)
    new_model_path = model_root + output_suffix + model_ext

    logger.info(f"Saving converted model to {new_model_path}")
    save_instructions_dict(new_model_path, instructions_dict)

    # try to load again
    logger.info(f"Trying to load converted model from {new_model_path}")
    new_instructions = load_instructions(new_model_path)
    tp_new = TensorPotential(potential=new_instructions)
    logger.info(f"Trying to load checkpoint from {new_checkpoint_path}")
    tp_new.load_checkpoint(
        checkpoint_name=new_checkpoint_path,
        expect_partial=True,
        verbose=True,
        raise_errors=True,
    )

    logger.info("Finished")


def reduce_elements(args):
    """Write a copy of a dict-like model restricted to the selected elements.

    The model and checkpoint are loaded, the energy of a two-atom structure
    (positions ``(0, 0, 0)`` and ``(2, 0, 0)``, non-periodic) is computed for
    every requested element, ``convert_model_reduce_elements`` writes the
    reduced model/checkpoint, and the reduced model is loaded again to verify
    that it reproduces the reference energies.

    Output locations:

    * checkpoint -> ``<checkpoint_path><output_suffix>``
    * model      -> the model path with ``output_suffix`` inserted before its
      file extension

    Args:
        args: Parsed CLI namespace; ``potential``, ``checkpoint_path``,
            ``output_suffix`` and ``elements`` are read (after
            ``preprocess_args``).

    Raises:
        ValueError: If no checkpoint path was given.
        AssertionError: If the reduced model gives a different energy for one
            of the test structures.
        SystemExit: With status 1 when the model is not in dict-like format.
    """
    preprocess_args(args)

    import os

    from tensorpotential import TensorPotential
    from tensorpotential.instructions.base import load_instructions
    from tensorpotential.calculator.asecalculator import TPCalculator
    from tensorpotential.utils import convert_model_reduce_elements

    model_path, checkpoint_path, output_suffix, elements = (
        args.potential,
        args.checkpoint_path,
        args.output_suffix,
        args.elements,
    )

    # preprocess_args() has already stripped a trailing ".index"; what is left
    # to validate is that a checkpoint was given at all (the CLI option is
    # optional, and None would otherwise fail with an obscure AttributeError).
    if checkpoint_path is None:
        raise ValueError(
            "reduce_elements requires a checkpoint path (-c/--checkpoint-path)"
        )

    logger.info(
        f"Selecting {elements} elements from model {model_path}, with checkpoint {checkpoint_path}"
    )

    logger.info(f"Loading model from {model_path}")
    instructions_dict = load_instructions(model_path)
    if not isinstance(instructions_dict, dict):
        # Nothing can be done with this model, so report it as an error and
        # exit with a non-zero status instead of signalling success.
        logger.error(
            f"Model in {model_path} is NOT in new format (dict-like). Convert it using `update_model` option"
        )
        sys.exit(1)

    tp = TensorPotential(potential=instructions_dict)
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    tp.load_checkpoint(
        checkpoint_name=checkpoint_path,
        expect_partial=True,
        verbose=True,
        raise_errors=True,
    )

    new_checkpoint_path = checkpoint_path + output_suffix
    # Insert the suffix in front of the real file extension only. A plain
    # str.replace(".yaml", ...) would also rewrite directory names containing
    # ".yaml"/".yml" and, for a file without such an extension, would leave the
    # name unchanged so that the source model got overwritten.
    model_root, model_ext = os.path.splitext(model_path)
    new_model_path = model_root + output_suffix + model_ext

    # generate test structures: reference energies from the full model
    test_atoms = {}
    calc = TPCalculator(model=tp.model)
    for el in elements:
        at = Atoms(positions=[[0, 0, 0], [2, 0, 0]], symbols=[el, el], pbc=False)
        at.calc = calc
        en = at.get_potential_energy()
        logger.info(f"{el} dimer: energy = {en}")
        test_atoms[el] = {"atoms": at, "e": en}

    convert_model_reduce_elements(
        elements,
        potential_file_name=model_path,
        checkpoint_name=checkpoint_path,
        new_potential_file_name=new_model_path,
        new_checkpoint_name=new_checkpoint_path,
    )

    # try to load again
    logger.info("==========================")
    logger.info(f"Trying to load converted model from {new_model_path}")
    new_instructions = load_instructions(new_model_path)
    tp_new = TensorPotential(potential=new_instructions)
    logger.info(f"Trying to load checkpoint from {new_checkpoint_path}")
    tp_new.load_checkpoint(
        checkpoint_name=new_checkpoint_path,
        expect_partial=True,
        verbose=True,
        raise_errors=True,
    )

    # Re-evaluate the same structures with the reduced model.
    calc_new = TPCalculator(model=tp_new.model)
    for el, info in test_atoms.items():
        at = info["atoms"]
        en_ref = info["e"]
        at.calc = calc_new
        en = at.get_potential_energy()
        # Explicit raise instead of `assert` so the verification is not
        # stripped under `python -O`; the exception type is unchanged.
        if not np.allclose(en, en_ref):
            raise AssertionError(
                f"Energy not equal for element {el}: {en} != {en_ref}"
            )
        diff = en - en_ref
        logger.info(f"{el} dimer: OK, diff = {diff:.5g} eV")

    logger.info("Finished")


def cast_model(args):
    """Cast the variables of a dict-like model from one float precision to another.

    The model is built with the current precision (``args.curr``), the
    checkpoint is loaded, every TensorFlow ``Variable`` attribute of that dtype
    found on the instructions (and on their directly nested ``Module``
    attributes) is replaced by a copy in the target precision (``args.to``),
    and the result is saved as a checkpoint at
    ``<checkpoint_path><output_suffix>``.

    Args:
        args: Parsed CLI namespace; ``potential``, ``checkpoint_path``,
            ``output_suffix``, ``curr`` and ``to`` are read (after
            ``preprocess_args``). ``curr``/``to`` must be ``"fp32"`` or
            ``"fp64"``.

    Raises:
        ValueError: If no checkpoint path was given.
        AssertionError: If a dtype label is unknown or the model is not in
            dict format.
    """
    from tensorpotential import TensorPotential
    from tensorpotential.instructions.base import load_instructions
    from tensorflow import float32, float64, Variable, Module

    def cast_instruction_vars(instruction, from_dtype, to_dtype):
        """Replace each `from_dtype` Variable attribute of `instruction` by a `to_dtype` copy."""
        # Walk a snapshot of the attribute names: setattr() below rebinds
        # attributes of the very object whose __dict__ is being traversed.
        for att_name in list(instruction.__dict__):
            attr_value = getattr(instruction, att_name)
            if isinstance(attr_value, Variable) and attr_value.dtype == from_dtype:
                # Drop only the trailing ":0" of the variable name; the rest
                # of the name is passed on untouched.
                var_name = attr_value.name
                if var_name.endswith(":0"):
                    var_name = var_name[: -len(":0")]
                setattr(
                    instruction,
                    att_name,
                    Variable(attr_value.numpy(), dtype=to_dtype, name=var_name),
                )
            if isinstance(attr_value, Module):  # Check nested objects
                cast_instruction_vars(attr_value, from_dtype, to_dtype)

    preprocess_args(args)
    dtypes = {"fp32": float32, "fp64": float64}

    model_path, checkpoint_path, output_suffix, from_dtype, to_dtype = (
        args.potential,
        args.checkpoint_path,
        args.output_suffix,
        args.curr,
        args.to,
    )
    # The CLI option is optional; without this check a missing checkpoint only
    # surfaces later as an unrelated-looking error.
    if checkpoint_path is None:
        raise ValueError(
            "cast_model requires a checkpoint path (-c/--checkpoint-path)"
        )
    assert from_dtype in dtypes, f"Invalid dtype, {from_dtype} not in {dtypes.keys()}"
    assert to_dtype in dtypes, f"Invalid dtype, {to_dtype} not in {dtypes.keys()}"
    logger.info(
        f"Trying to cast model at model_path={model_path},"
        f" with checkpoint_path={checkpoint_path},"
        f" from dtype={from_dtype} to dtype={to_dtype}"
    )

    d_from = dtypes[from_dtype]
    d_to = dtypes[to_dtype]

    pot = load_instructions(model_path)
    tp = TensorPotential(potential=pot, fit_config={}, float_dtype=d_from)
    tp.load_checkpoint(
        checkpoint_name=checkpoint_path, expect_partial=True, verbose=True
    )
    # Explicit raise instead of `assert` so this input validation is not
    # stripped under `python -O`; the exception type is unchanged.
    if not isinstance(tp.model.instructions, dict):
        raise AssertionError(
            "Model is not dict format, cannot cast it. Use `update_model` option to convert it to dict format."
        )
    for ins in tp.model.instructions.values():
        cast_instruction_vars(ins, d_from, d_to)
    tp.save_checkpoint(checkpoint_name=checkpoint_path + output_suffix, verbose=True)
    logger.info(f"Successfully casted model at {checkpoint_path} to {to_dtype} dtype")


def export_model(args):
    """Export a model with its checkpoint to a saved model or a GRACE-FS/C++ YAML file.

    The instructions are loaded from ``args.potential`` and the checkpoint is
    loaded into the resulting ``TensorPotential``. Then, depending on
    ``args.sf``:

    * not set: ``save_model`` is called with ``args.saved_model_name`` as the
      exact path (with ``jit_compile=True``);
    * set: ``export_to_yaml`` is called; ``.yaml`` is appended to the name
      unless it already ends with ``.yaml`` or ``.yml``.

    Args:
        args: Parsed CLI namespace; ``potential``, ``checkpoint_path``, ``sf``
            and ``saved_model_name`` are read (after ``preprocess_args``).

    Raises:
        ValueError: If no checkpoint path was given.
    """
    from tensorpotential import TensorPotential
    from tensorpotential.instructions.base import load_instructions

    preprocess_args(args)

    model_path, checkpoint_path = args.potential, args.checkpoint_path
    save_fs = args.sf
    saved_model_name = args.saved_model_name

    # preprocess_args() has already stripped a trailing ".index"; what is left
    # to validate is that a checkpoint was given at all (the CLI option is
    # optional, and None would otherwise fail with an obscure AttributeError).
    if checkpoint_path is None:
        raise ValueError(
            "export requires a checkpoint path (-c/--checkpoint-path)"
        )

    logger.info(f"Exporting model {model_path} with checkpoint at {checkpoint_path}")

    logger.info(f"Loading model from {model_path}")
    instructions = load_instructions(model_path)
    tp = TensorPotential(potential=instructions)
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    tp.load_checkpoint(
        checkpoint_name=checkpoint_path,
        expect_partial=True,
        verbose=True,
        raise_errors=True,
    )

    if not save_fs:
        logger.info(f"Exporting model to {saved_model_name}")
        tp.save_model(exact_path=saved_model_name, jit_compile=True)
    else:
        # Append ".yaml" only when the name has neither YAML extension. The
        # tuple form avoids the precedence trap of `not a or b`, which used to
        # turn "model.yml" into "model.yml.yaml".
        if not saved_model_name.endswith((".yaml", ".yml")):
            saved_model_name += ".yaml"
        logger.info(f"Exporting FS model to {saved_model_name}")
        tp.export_to_yaml(exact_filename=saved_model_name)

    logger.info("Finished")


def model_summary(args):
    """Build the model described by ``args.potential`` and log its summary.

    Args:
        args: Parsed CLI namespace; ``potential`` (model path) and ``verbose``
            (verbosity passed on to the model's ``summary``) are read.
    """
    preprocess_args(args)

    from tensorpotential.instructions.base import load_instructions
    from tensorpotential.tpmodel import TPModel
    import tensorflow as tf

    potential_file, verbosity = args.potential, args.verbose

    logger.info(f"Loading model from {potential_file}")
    model = TPModel(load_instructions(potential_file))
    # The model is built in double precision before it is summarized.
    model.build(tf.float64)
    summary_text = model.summary(verbose=verbosity)
    logger.info(f"Summary (verbose={verbosity}) is:\n{summary_text}")


def main(argv=None):
    """Parse the command line and run the selected sub-command.

    Global options (``-p``, ``-c``, ``-os``) belong to the top-level parser and
    therefore go before the sub-command name on the command line.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain
            ``main()`` call.

    Raises:
        SystemExit: Raised by argparse (status 2) on invalid arguments,
            including a missing checkpoint for a command that needs one.
    """
    parser = argparse.ArgumentParser(
        prog="grace_utils",
        description="CLI tool for model conversions and summarization",
    )
    parser.add_argument("-p", "--potential", required=True, help="Path to model.yaml")
    parser.add_argument(
        "-c", "--checkpoint-path", required=False, help="Path to checkpoint"
    )
    parser.add_argument(
        "-os",
        "--output-suffix",
        type=str,
        default="converted",
        help="Output suffix for converted",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # update_model
    parser_update = subparsers.add_parser(
        "update_model", help="Update model (model.yaml) and corresponding checkpoint."
    )
    parser_update.set_defaults(func=lambda args: update_model(args))

    # reduce_elements
    parser_reduce = subparsers.add_parser(
        "reduce_elements", help="Reduce elements from the model."
    )
    parser_reduce.add_argument(
        "-e", "--elements", nargs="+", required=True, help="Elements to select"
    )
    parser_reduce.set_defaults(func=lambda args: reduce_elements(args))

    # cast_model (precision)
    parser_precision = subparsers.add_parser(
        "cast_model", help="Change model's floating point precision."
    )
    parser_precision.add_argument(
        "-curr",
        required=True,
        choices=["fp32", "fp64"],
        help="Current precision type to cast from",
    )
    parser_precision.add_argument(
        "-to",
        required=True,
        choices=["fp32", "fp64"],
        help="New precision type to cast into",
    )
    parser_precision.set_defaults(func=lambda args: cast_model(args))

    # export
    parser_export = subparsers.add_parser(
        "export", help="Export model to saved_model or FS/C++ format."
    )
    parser_export.add_argument(
        "-sf",
        action="store_true",
        default=False,
        help="Save to GRACE-FS/C++ YAML model format",
    )
    parser_export.add_argument(
        "-n",
        "--saved-model-name",
        dest="saved_model_name",
        type=str,
        default="saved_model",
        help="Output name/path of the exported model",
    )
    parser_export.set_defaults(func=lambda args: export_model(args))

    # summary
    parser_summary = subparsers.add_parser("summary", help="Show info about the model")
    parser_summary.add_argument(
        "-v",
        "--verbose",
        default=0,
        type=int,
        choices=[0, 1, 2],
        help="Verbosity level: 0, 1 or 2",
    )

    parser_summary.set_defaults(func=lambda args: model_summary(args))

    # Parse arguments and execute the corresponding function
    args = parser.parse_args(argv)

    # Every command except `summary` loads a checkpoint. Reject a missing
    # -c/--checkpoint-path here with a proper usage error instead of letting
    # the command crash on None.
    commands_without_checkpoint = {"summary"}
    if args.command not in commands_without_checkpoint and args.checkpoint_path is None:
        parser.error(
            f"-c/--checkpoint-path is required for the `{args.command}` command"
        )

    args.func(args)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
