#!python

import warnings

# Silence FutureWarning / DeprecationWarning globally. The filters are installed
# before the imports below, presumably so that warnings emitted while those
# libraries are being imported are suppressed as well.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import sys
import os
import pandas as pd
import argparse

from tensorpotential.calculator import TPCalculator

import logging

# Log line layout: timestamp, first letter of the level name, then the message.
LOG_FMT = "%(asctime)s %(levelname).1s - %(message)s"
# Configure the root logger at INFO level with the format above.
logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%Y/%m/%d %H:%M:%S")
# getLogger() without a name returns the root logger.
logger = logging.getLogger()

from tqdm import tqdm

# Register tqdm with pandas so that DataFrame.progress_apply (used in main) is available.
tqdm.pandas()


def predict(row, calc):
    """Compute energy and forces for a single dataset row.

    The structure is taken from ``row["ase_atoms"]`` and copied; the
    calculator is attached to the copy rather than to the object stored in
    the row. When the row has a ``"mag_mom"`` entry, it is applied to the copy
    as initial magnetic moments before the calculator is attached.

    Args:
        row: Mapping-like row providing ``"ase_atoms"`` and, optionally,
            ``"mag_mom"``.
        calc: Calculator object assigned to the copied structure.

    Returns:
        A dict with the keys ``"energy"`` and ``"forces"``.
    """
    structure = row["ase_atoms"].copy()
    has_moments = "mag_mom" in row
    if has_moments:
        structure.set_initial_magnetic_moments(row["mag_mom"])
    structure.calc = calc

    # Energy is requested before forces, in that order.
    result = {}
    result["energy"] = structure.get_potential_energy()
    result["forces"] = structure.get_forces()
    return result


def main(args):
    """Predict energies and forces for a pickled dataset and save the result.

    Parses the command-line options, loads the model into a ``TPCalculator``,
    reads the gzip-compressed pickle dataset, runs ``predict`` on every row
    and writes a gzip-compressed pickle with the added columns
    ``energy_predicted`` and ``forces_predicted``. The ``ase_atoms`` column is
    not written to the output.

    Args:
        args: List of command-line arguments (without the program name),
            e.g. ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-m",
        "--model_path",
        help="provide path to the saved_model directory",
        type=str,
        default="saved_model",
        dest="model_path",
    )

    parser.add_argument(
        "-d",
        "--dataset",
        help="path to the dataset.pkl.gzip containing ase_atoms structures",
        type=str,
        default="dataset.pkl.gz",
        dest="dataset_file",
    )

    parser.add_argument(
        "-o",
        "--output",
        help="path to the OUTPUT dataset (pkl.gzip) containing energy_predicted and forces_predicted",
        type=str,
        default="predicted_dataset.pkl.gz",
        dest="output",
    )

    args_parse = parser.parse_args(args)

    model_path = os.path.abspath(args_parse.model_path)
    dataset_file = args_parse.dataset_file
    output_file = args_parse.output

    logger.info(f"Loading model from: {model_path}")
    calc = TPCalculator(
        model=model_path,
        pad_atoms_number=20,
        pad_neighbors_fraction=0.30,
        # max_number_reduction_recompilation=3,
    )

    logger.info(f"Loading dataset from: {dataset_file}")
    # NOTE(security): unpickling executes arbitrary code embedded in the file;
    # only load datasets that come from a trusted source.
    df = pd.read_pickle(dataset_file, compression="gzip")

    logger.info("Starting prediction")

    # Each row yields a dict {"energy": ..., "forces": ...}; unpack it into
    # two dedicated columns.
    df["prediction"] = df.progress_apply(predict, axis=1, args=(calc,))
    df["energy_predicted"] = df["prediction"].map(lambda x: x["energy"])
    df["forces_predicted"] = df["prediction"].map(lambda x: x["forces"])

    # Remove the structures and the intermediate column exactly once. Dropping
    # "ase_atoms" a second time would raise KeyError (DataFrame.drop defaults
    # to errors="raise") and the output would never be written.
    df = df.drop(columns=["ase_atoms", "prediction"])

    logger.info(f"Saving dataset to {output_file}")
    df.to_pickle(output_file, compression="gzip")


# Script entry point: pass the command-line arguments (without the program name) to main.
if __name__ == "__main__":
    main(sys.argv[1:])
