#!/usr/bin/env python
import os
import sys
import argparse


def list_models(_):
    """Print every model in MODELS_METADATA with its description, paths and license.

    For each model the report shows:

    * the description, if the metadata has one (continuation lines are
      tab-indented),
    * the saved-model path when it is an existing directory, otherwise
      ``[NOT DOWNLOADED]``,
    * the checkpoint path, only for models whose metadata has a checkpoint
      URL entry,
    * the license, or ``not provided`` when the metadata has none.

    Args:
        _: Parsed command-line namespace; unused, accepted so the function
            matches the sub-command handler signature.
    """
    # Imported lazily so that building the parser / printing help does not
    # require importing tensorpotential.
    from tensorpotential.calculator.foundation_models import (
        MODELS_METADATA,
        FOUNDATION_CACHE_DIR,
        FOUNDATION_CHECKPOINTS_CACHE_DIR,
        LICENSE_KEY,
        CHECKPOINT_URL_KEY,
        CHECKPOINT_PATH_KEY,
        MODEL_PATH_KEY,
        DESCRIPTION_KEY,
    )

    separator = "=" * 80
    print("Available models:")
    print(separator)

    for model, model_data in MODELS_METADATA.items():
        lines = [f"{model}"]

        if DESCRIPTION_KEY in model_data:
            # Read with the same key that was tested above (not a hard-coded
            # literal) and indent continuation lines under the header.
            description = model_data[DESCRIPTION_KEY].replace("\n", "\n\t")
            lines.append("\tDESCRIPTION: " + description)

        # Saved model: explicit path from the metadata, else the cache location.
        model_path = model_data.get(MODEL_PATH_KEY) or os.path.join(
            FOUNDATION_CACHE_DIR, model
        )
        if os.path.isdir(model_path):
            lines.append(f"\tPATH: {model_path}")
        else:
            lines.append("\tPATH: [NOT DOWNLOADED]")

        # Checkpoint: only reported for models that advertise a checkpoint URL.
        if CHECKPOINT_URL_KEY in model_data:
            checkpoint_path = model_data.get(CHECKPOINT_PATH_KEY) or os.path.join(
                FOUNDATION_CHECKPOINTS_CACHE_DIR, model
            )
            if os.path.isdir(checkpoint_path):
                lines.append(f"\tCHECKPOINT: {checkpoint_path}")
            else:
                lines.append("\tCHECKPOINT: AVAILABLE, BUT NOT DOWNLOADED")

        print("\n".join(lines))
        print(f"\tLICENSE: {model_data.get(LICENSE_KEY, 'not provided')}")
        print(separator)


def download_model(args):
    """Fetch one foundation model, or all of them, via get_or_download_model.

    Args:
        args: Parsed command-line namespace. ``args.model_name`` is either a
            key of MODELS_METADATA or the literal ``all``, which selects
            every entry of MODELS_NAME_LIST.

    An unknown name only prints a "not found" message; nothing is raised.
    """

    from tensorpotential.calculator.foundation_models import (
        get_or_download_model,
        MODELS_METADATA,
        MODELS_NAME_LIST,
    )

    requested = args.model_name

    # Resolve the request to the names to fetch. A real model name takes
    # precedence over the "all" keyword.
    if requested in MODELS_METADATA:
        targets = [requested]
    elif requested == "all":
        print(f"Downloading ALL models: {MODELS_NAME_LIST}")
        targets = MODELS_NAME_LIST
    else:
        print(f"Model {requested} not found in available models.")
        return

    for name in targets:
        print(f"Downloading model: {name}")
        get_or_download_model(name)
        print(f"Model {name} downloaded successfully.")


def download_checkpoint(args):
    """Download the checkpoint of one model, or of every model for ``all``.

    Args:
        args: Parsed command-line namespace. ``args.model_name`` is either a
            key of MODELS_METADATA or the literal ``all``, which selects
            every entry of MODELS_NAME_LIST.

    An unknown name only prints a "not found" message; nothing is raised.
    """

    from tensorpotential.calculator.foundation_models import (
        get_or_download_checkpoint,
        MODELS_METADATA,
        MODELS_NAME_LIST,
    )

    model_name = args.model_name

    # Resolve the request to the names to fetch. A real model name takes
    # precedence over the "all" keyword.
    if model_name in MODELS_METADATA:
        targets = [model_name]
    elif model_name == "all":
        print(f"Downloading checkpoints for ALL models: {MODELS_NAME_LIST}")
        # NOTE(review): every listed model is attempted; whether
        # get_or_download_checkpoint copes with models that have no
        # checkpoint URL is not visible from this file - confirm.
        targets = MODELS_NAME_LIST
    else:
        print(f"Model {model_name} not found in available models.")
        return

    # Same progress messages for the single-model and the "all" case, so the
    # output always says "checkpoint" rather than "model".
    for name in targets:
        print(f"Downloading checkpoint for {name}")
        get_or_download_checkpoint(name)
        print(f"Checkpoint for {name} downloaded successfully.")


def build_parser():
    """Create the ``grace_models`` argument parser.

    Three sub-commands are registered; the chosen one is stored in
    ``command`` and its handler in ``func``:

    * ``list`` -> list_models (no arguments)
    * ``download <model_name>`` -> download_model
    * ``checkpoint <model_name>`` -> download_checkpoint

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    root = argparse.ArgumentParser(
        prog="grace_models", description="Download foundation GRACE models"
    )
    commands = root.add_subparsers(dest="command", help="Sub-command help")

    # "list" is the only sub-command without arguments.
    list_cmd = commands.add_parser("list", help="List available models")
    list_cmd.set_defaults(func=list_models)

    # The remaining sub-commands share one shape: a positional model_name.
    # Each row is (command, command help, model_name help, handler).
    per_model_commands = (
        (
            "download",
            "Download a model",
            "Name of the model to download",
            download_model,
        ),
        (
            "checkpoint",
            "Download a checkpoint",
            "Name of the model to download checkpoint for",
            download_checkpoint,
        ),
    )
    for command, command_help, name_help, handler in per_model_commands:
        sub = commands.add_parser(command, help=command_help)
        sub.add_argument("model_name", type=str, help=name_help)
        sub.set_defaults(func=handler)

    return root


def main(args):
    """Parse ``args`` and run the selected sub-command.

    Args:
        args: Command-line arguments without the program name
            (e.g. ``sys.argv[1:]``).

    When no sub-command is given, the help text is printed instead.
    """
    cli = build_parser()
    parsed = cli.parse_args(args)

    # No sub-command selected: there is no handler to call, so show usage.
    if not parsed.command:
        cli.print_help()
        return

    parsed.func(parsed)


# Script entry point: pass the command-line arguments (without the program
# name) to main() when this file is executed directly.
if __name__ == "__main__":
    main(sys.argv[1:])
