#!python

import os
import pandas as pd
import argparse
import sys
from ase.io import read

import logging

# Log line layout: timestamp, first letter of the level name, then the message.
LOG_FMT = "%(asctime)s %(levelname).1s - %(message)s"
logging.basicConfig(
    format=LOG_FMT,
    datefmt="%Y/%m/%d %H:%M:%S",
    level=logging.INFO,
)
# getLogger() without a name returns the root logger configured just above.
logger = logging.getLogger()


def sizeof_fmt(file_name_or_size, suffix="B"):
    """Format a byte count as a human-readable string with binary prefixes.

    Args:
        file_name_or_size: Either a number of bytes, or a ``str`` path whose
            on-disk size (via ``os.path.getsize``) is formatted instead.
        suffix: Unit suffix appended after the binary prefix.

    Returns:
        The size with one decimal place, e.g. ``"1.5KiB"``. Values that are
        still >= 1024 after the ``Zi`` step are reported with the ``Yi``
        prefix.
    """
    size = file_name_or_size
    if isinstance(size, str):
        size = os.path.getsize(size)

    binary_prefixes = ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi")
    for prefix in binary_prefixes:
        if abs(size) >= 1024.0:
            # Too large for this prefix: scale down and try the next one.
            size /= 1024.0
            continue
        return "%3.1f%s%s" % (size, prefix, suffix)
    return "%.1f%s%s" % (size, "Yi", suffix)


def _default_output_filename(extxyz_filename):
    """Derive the default pickle filename from the input extxyz filename.

    A trailing ``.extxyz`` or ``.xyz`` extension is swapped for ``.pkl.gz``.
    Any other name gets ``.pkl.gz`` appended, so the result never equals the
    input path (which would make the output overwrite the input).
    """
    for extension in (".extxyz", ".xyz"):
        if extxyz_filename.endswith(extension):
            return extxyz_filename[: -len(extension)] + ".pkl.gz"
    return extxyz_filename + ".pkl.gz"


def main(args):
    """Convert an extxyz file into a gzip-compressed pickled pandas DataFrame.

    The resulting DataFrame has one row per structure read from the file, with
    the columns ``ase_atoms``, ``energy``, ``forces``, ``info`` and, when it
    can be extracted, ``stress``.

    Args:
        args: Command-line arguments (without the program name). Expects the
            positional ``extxyz_filename`` and the optional
            ``--output-dataset-filename``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "extxyz_filename",
        help="Name of extxyz file",
        type=str,
    )

    parser.add_argument(
        "--output-dataset-filename",
        help="pickle filename, default is inferenced from extxyz_filename",
        type=str,
        default=None,
    )

    args_parse = parser.parse_args(args)
    extxyz_filename = os.path.abspath(args_parse.extxyz_filename)
    output_dataset_filename = args_parse.output_dataset_filename
    if output_dataset_filename is None:
        # Only the trailing extension is rewritten. A chained str.replace()
        # left "*.extxyz" names unchanged, so the output overwrote the input.
        output_dataset_filename = _default_output_filename(extxyz_filename)

    logger.info(f"Input filename: {extxyz_filename}")
    logger.info(f"Output filename: {output_dataset_filename}")

    logger.info(
        f"Reading extxyz file: {extxyz_filename} ({sizeof_fmt(extxyz_filename)})"
    )
    # index=":" requests every frame in the file, not just the last one.
    data = read(extxyz_filename, format="extxyz", index=":")
    logger.info(f"{len(data)} structures read from {extxyz_filename}")

    logger.info("Converting to dataframe")
    df = pd.DataFrame({"ase_atoms": data})

    logger.info("Extracting energy")
    df["energy"] = df["ase_atoms"].map(
        lambda at: at.get_potential_energy(force_consistent=True)
    )

    logger.info("Extracting forces")
    df["forces"] = df["ase_atoms"].map(lambda at: at.get_forces())

    # Stress is optional: if extraction fails for any structure, the error is
    # logged and the "stress" column is left out instead of aborting the run.
    try:
        logger.info("Trying to extract stresses")
        df["stress"] = df["ase_atoms"].map(lambda at: at.get_stress())
    except Exception as e:
        logger.error(f"Error while trying to extract stress: {e}, ignoring")

    logger.info("Extracting info")
    df["info"] = df["ase_atoms"].map(lambda at: at.info)

    logger.info("Saving dataframe")
    df.to_pickle(output_dataset_filename, compression="gzip")
    logger.info(
        f"Saved dataframe to {output_dataset_filename} ({sizeof_fmt(output_dataset_filename)})"
    )


# Script entry point: pass the command-line arguments (program name excluded) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
