#!python
import argparse
import logging
import os
import sys

import pandas as pd
from ase.calculators.singlepoint import SinglePointCalculator

# Log line layout: timestamp, first letter of the level name, then the message.
LOG_FMT = "%(asctime)s %(levelname).1s - %(message)s"
logging.basicConfig(
    format=LOG_FMT,
    datefmt="%Y/%m/%d %H:%M:%S",
    level=logging.INFO,
)
# getLogger() without a name returns the root logger.
logger = logging.getLogger()


from ase.io import write


def build_parser():
    """Create the command-line parser for the ``df2extxyz`` converter.

    The parser accepts one positional argument (``input``) and four optional
    string arguments: the energy, force and stress column names and the
    output file name.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    cli = argparse.ArgumentParser(
        prog="df2extxyz", description="Conversion from df.pkl.gz to extxyz"
    )
    cli.add_argument("input", type=str, help="input pkl.gz file")

    # (short flag, long flag, help text, default) in the order they are
    # registered, which is also the order shown by --help.
    option_specs = (
        ("-e", "--energy-column", "name of energy column", "energy_corrected"),
        ("-f", "--force-column", "name of forces column", "forces"),
        ("-s", "--stress-column", "name of stress column", "stress"),
        ("-o", "--output", "output file name", None),
    )
    for short_flag, long_flag, help_text, default_value in option_specs:
        cli.add_argument(
            short_flag,
            long_flag,
            help=help_text,
            type=str,
            default=default_value,
        )
    return cli


def sizeof_fmt(file_name_or_size, suffix="B"):
    """Format a byte count as a human-readable string with binary prefixes.

    Args:
        file_name_or_size: a number of bytes, or a ``str`` path whose size as
            reported by ``os.path.getsize`` is formatted instead.
        suffix: text appended after the unit prefix.

    Returns:
        str: the scaled value with one decimal place followed by the prefix
        and suffix, e.g. ``"1.5KiB"``.
    """
    size = file_name_or_size
    if isinstance(size, str):
        size = os.path.getsize(size)

    step = 1024.0
    prefixes = ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi")
    for prefix in prefixes:
        if abs(size) < step:
            return "%3.1f%s%s" % (size, prefix, suffix)
        size = size / step
    # Anything still >= 1024 after the "Zi" step is reported in "Yi".
    return "%.1f%s%s" % (size, "Yi", suffix)


def _default_output_name(input_fname):
    """Return the default extxyz output path derived from ``input_fname``."""
    # Only a trailing extension is swapped; matching anywhere in the path
    # would also rewrite directory names that happen to contain it.
    for ext in (".pckl.gzip", ".pckl.gz", ".pkl.gzip", ".pkl.gz"):
        if input_fname.endswith(ext):
            return input_fname[: -len(ext)] + ".xyz"
    # Unrecognised extension: append rather than return the input path
    # unchanged, which would make the writer overwrite the input file.
    return input_fname + ".xyz"


def main(args):
    """Convert a gzip-pickled dataframe of structures into an extxyz file.

    Each row's ``ase_atoms`` object gets the reference energy stored in
    ``info["REF_energy"]``, the forces in ``arrays["REF_forces"]`` and, when
    the stress column exists, the stress in ``info["REF_stress"]``. The same
    values are attached through a ``SinglePointCalculator`` before all
    structures are written to the output file.

    Args:
        args: list of command-line arguments (without the program name), as
            accepted by the parser from ``build_parser``.

    Raises:
        KeyError: if a row lacks the ``ase_atoms``, energy or force column.
    """
    parser = build_parser()
    args_parse = parser.parse_args(args)
    input_fname = args_parse.input
    energy_col = args_parse.energy_column
    force_col = args_parse.force_column
    stress_col = args_parse.stress_column

    logging.info(f"Reading input file : {input_fname} ({sizeof_fmt(input_fname)})")
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    df = pd.read_pickle(input_fname, compression="gzip")
    logging.info(f"Dataframe shape: {df.shape}")
    df.reset_index(drop=True, inplace=True)

    # The stress column is optional; its presence is the same for every row.
    has_stress = stress_col in df.columns

    atoms_list = []
    for _, row in df.iterrows():
        at = row["ase_atoms"]
        energy = row[energy_col]
        forces = row[force_col]
        stress = row[stress_col] if has_stress else None

        at.info["REF_energy"] = energy
        at.arrays["REF_forces"] = forces
        if has_stress:
            at.info["REF_stress"] = stress

        at.calc = SinglePointCalculator(
            at,
            energy=energy,
            forces=forces,
            stress=stress,
        )

        atoms_list.append(at)

    output_fname = args_parse.output
    if output_fname is None:
        output_fname = _default_output_name(input_fname)
    logging.info(f"Writing to output filename: {output_fname}")
    write(output_fname, atoms_list, format="extxyz")
    logging.info(f"Saved to {output_fname} ({sizeof_fmt(output_fname)})")


if __name__ == "__main__":
    # Drop the program name; main() expects only the actual arguments.
    main(args=sys.argv[1:])
