#!/usr/bin/env python
import os
import sys
import argparse


def list_models(_):
    """List available models from MODELS_METADATA.

    Prints one entry per model: the model name followed by its local
    directory when that directory exists under FOUNDATION_CACHE_DIR (or a
    ``[NOT DOWNLOADED]`` marker otherwise), the optional description
    indented with a tab, and the license (``not provided`` when absent).

    Args:
        _: Ignored. Present because sub-command handlers are invoked with
            the parsed argparse namespace.
    """
    # Deferred import: tensorpotential is only loaded when this command runs.
    from tensorpotential.calculator.foundation_models import (
        MODELS_METADATA,
        FOUNDATION_CACHE_DIR,
    )

    separator = "=" * 80

    print("Available models:")
    print(separator)

    for model, model_data in MODELS_METADATA.items():
        model_path = os.path.join(FOUNDATION_CACHE_DIR, model_data["dirname"])
        if os.path.isdir(model_path):
            msg = f"{model}: {model_path}"
        else:
            msg = f"{model}: [NOT DOWNLOADED]"
        if "description" in model_data:
            # Indent every line of a multi-line description with a tab.
            msg += "\n\t" + model_data["description"].replace("\n", "\n\t")

        print(msg)
        print(f"LICENSE: {model_data.get('license', 'not provided')}")
        print(separator)


def download_model(args):
    """Download one foundation model, or every model when ``all`` is given.

    Args:
        args: Parsed argparse namespace; only ``args.model_name`` is read.
            A name present in MODELS_METADATA downloads that model, the
            literal ``all`` downloads every entry of MODELS_NAME_LIST, and
            anything else prints a "not found" message.
    """
    # Deferred import: tensorpotential is only loaded when this command runs.
    from tensorpotential.calculator.foundation_models import (
        MODELS_METADATA,
        MODELS_NAME_LIST,
        download_fm,
    )

    requested = args.model_name

    def _fetch(name):
        # Download a single model, reporting progress before and after.
        print(f"Downloading model: {name}")
        download_fm(name)
        print(f"Model {name} downloaded successfully.")

    # A known model name takes precedence over the special "all" keyword.
    if requested in MODELS_METADATA:
        _fetch(requested)
        return

    if requested != "all":
        print(f"Model {requested} not found in available models.")
        return

    print(f"Downloading ALL models: {MODELS_NAME_LIST}")
    for name in MODELS_NAME_LIST:
        _fetch(name)
    print("All models were downloaded")


def build_parser():
    """Create the ``grace_models`` command-line parser.

    The parser exposes two sub-commands, ``list`` and ``download``. The
    selected sub-command name is stored under ``command`` and its handler
    under ``func`` in the parsed namespace.

    Returns:
        argparse.ArgumentParser: The configured top-level parser.
    """
    root = argparse.ArgumentParser(
        prog="grace_models",
        description="Download foundation GRACE models",
    )
    commands = root.add_subparsers(dest="command", help="Sub-command help")

    # `list` takes no arguments of its own.
    commands.add_parser("list", help="List available models").set_defaults(
        func=list_models
    )

    # `download` takes a single positional model name.
    download_cmd = commands.add_parser("download", help="Download a model")
    download_cmd.add_argument(
        "model_name",
        type=str,
        help="Name of the model to download",
    )
    download_cmd.set_defaults(func=download_model)

    return root


def main(args):
    """Parse ``args`` and run the selected sub-command.

    Args:
        args: Sequence of command-line argument strings (without the
            program name). When no sub-command is given, the help text is
            printed instead.
    """
    cli = build_parser()
    namespace = cli.parse_args(args)

    # No sub-command selected: show usage rather than failing.
    if not namespace.command:
        cli.print_help()
        return

    namespace.func(namespace)


# Script entry point: forward the command-line arguments (without the
# program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
