#!python

import hydra
import os
import transformers
from pathlib import Path
from omegaconf import OmegaConf
from hydra.core.hydra_config import HydraConfig

import logging

# Logger used by this script; it writes under the "lm_polygraph" logger name.
log = logging.getLogger("lm_polygraph")
# Only let WARNING and above through from the "httpx" logger.
logging.getLogger("httpx").setLevel(logging.WARNING)

from lm_polygraph.utils.manager import UEManager
from lm_polygraph.utils.dataset import Dataset
from lm_polygraph.utils.model import WhiteboxModel, BlackboxModel
from lm_polygraph.model_adapters import WhiteboxModelvLLM
from lm_polygraph.model_adapters.visual_whitebox_model import VisualWhiteboxModel
from lm_polygraph.utils.processor import Logger
from lm_polygraph.generation_metrics import *
from lm_polygraph.estimators import *
from lm_polygraph.ue_metrics import *
from lm_polygraph.utils.common import load_external_module, load_processor, load_image
from lm_polygraph.utils.generation_parameters import GenerationParameters, GenerationParametersFactory
from lm_polygraph.defaults.register_default_stat_calculators import (
    register_default_stat_calculators,
)
from lm_polygraph.utils.builder_enviroment_stat_calculator import (
    BuilderEnvironmentStatCalculator,
)
from lm_polygraph.utils.factory_estimator import FactoryEstimator
from lm_polygraph.utils.factory_stat_calculator import StatCalculatorContainer
#from transformers import AutoProcessor, AutoModelForVision2Seq

# Path of the Hydra config file, taken from the HYDRA_CONFIG environment
# variable. When the variable is unset this is Path(""), which compares equal
# to Path("."); main() checks for that value and replaces it at run time.
hydra_config = Path(os.environ.get("HYDRA_CONFIG", ""))


def get_hydra_config_path():
    """Return the config file path that was supplied on the command line.

    Looks through the config sources recorded in the Hydra runtime section
    and picks the first one whose provider is ``"command-line"``.

    Returns:
        ``Path(<source path>) / <job config name>`` for that source, or
        ``None`` when no command-line source is present.
    """
    hydra_cfg = HydraConfig.get()
    sources = hydra_cfg.get("runtime", {}).get("config_sources", [])
    cli_source = next(
        (entry for entry in sources if entry.get("provider") == "command-line"),
        None,
    )
    if cli_source is None:
        return None
    return Path(cli_source.get("path")) / hydra_cfg.job.config_name


def wandb_save_directory(directory_path):
    """Pass every regular file directly inside ``directory_path`` to ``wandb.save``.

    Sub-directories are skipped; the directory is not walked recursively.

    Args:
        directory_path: Directory whose immediate files are saved.
    """
    import wandb

    candidates = (
        os.path.join(directory_path, entry) for entry in os.listdir(directory_path)
    )
    for candidate in candidates:
        # Skip anything that is not a plain file (e.g. nested directories).
        if not os.path.isfile(candidate):
            continue
        wandb.save(candidate)


@hydra.main(
    version_base=None,
    config_path=str(hydra_config.parent),
    config_name=str(hydra_config.name),
)
def main(args):
    """Run the uncertainty-estimation benchmark described by the Hydra config.

    For every seed in ``args.seed`` (``[1]`` when it is ``None`` or empty) this
    loads the model and the dataset, builds the estimators, generation metrics
    and UE metrics, runs a ``UEManager`` and saves it to
    ``<save_path>/ue_manager_seed<seed>``. The manager is saved even when the
    run fails; in that case its ``state`` is set to ``"failed"`` and the
    exception is re-raised.

    When ``args.report_to_wandb`` is truthy, a wandb run is initialised with the
    resolved config, the files in ``<output_dir>/.hydra`` are saved to it, and
    the manager's metrics are logged after each seed.

    Args:
        args: Configuration object supplied by Hydra.
    """
    save_path = HydraConfig.get().runtime.output_dir
    os.chdir(hydra.utils.get_original_cwd())

    global hydra_config
    # Path("") == Path("."): HYDRA_CONFIG was not set, so fall back to the
    # config location recorded by Hydra for this run.
    if hydra_config == Path("."):
        hydra_config = get_hydra_config_path()

    report_to_wandb = bool(getattr(args, "report_to_wandb", False))

    if report_to_wandb:
        import wandb

        wandb_cfg = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)
        config_path_hydra = [
            path["path"]
            for path in HydraConfig.get().runtime.config_sources
            if path["schema"] == "file"
        ][0]
        wandb_cfg["HYDRA_CONFIG"] = (
            Path(config_path_hydra) / HydraConfig.get().job.config_name
        )
        os.environ["WANDB_DIR"] = str(Path(save_path))
        project = os.environ["WANDB_PROJECT"]
        wandb.init(project=project, dir=save_path, config=wandb_cfg)
        wandb_save_directory(Path(save_path) / ".hydra")

    # An explicit save_path in the config overrides Hydra's output directory.
    save_path = args.save_path if "save_path" in args else save_path
    log.info(f"Main directory: {save_path}")

    if args.seed is None or len(args.seed) == 0:
        args.seed = [1]

    cache_kwargs = {}
    if os.environ.get("HF_DATASETS_OFFLINE", "").strip() == "1":
        cache_kwargs = {"cache_dir": args.cache_path}

    for seed in args.seed:
        log.info("=" * 100)
        log.info(f"SEED: {seed}")

        log.info(f"Loading model {args.model.path}...")
        transformers.set_seed(seed)

        model = get_model(args)

        log.info("Done with loading model.")

        log.info(f"Loading dataset {args.dataset}...")
        dataset = Dataset.load(
            args.dataset,
            args.text_column,
            getattr(args, "label_column", None),
            batch_size=args.batch_size,
            prompt=getattr(args, "prompt", ""),
            description=getattr(args, "description", ""),
            mmlu_max_subject_size=getattr(args, "mmlu_max_subject_size", 100),
            n_shot=getattr(args, "n_shot", 5),
            few_shot_split=getattr(args, "few_shot_split", "train"),
            few_shot_prompt=getattr(args, "few_shot_prompt", None),
            im_column=getattr(args, "im_column", None),
            instruct=getattr(args, "instruct", None),
            split=args.eval_split,
            load_from_disk=args.load_from_disk,
            trust_remote_code=getattr(args, "trust_remote_code", False),
            **cache_kwargs,
        )
        log.info("Done with loading eval data.")

        log.info("=" * 100)
        log.info("Initializing UE estimators...")
        estimators = []
        estimators += get_ue_methods(args, model)
        log.info("Done loading UE estimators")

        if args.subsample_eval_dataset != -1:
            dataset.subsample(args.subsample_eval_dataset, seed=seed)

        generation_metrics = get_generation_metrics(args)

        ue_metrics = get_ue_metrics(args)

        builder_env_stat_calc = BuilderEnvironmentStatCalculator(model=model)
        available_stat_calculators = get_stat_calculator_names(args)

        man = UEManager(
            data=dataset,
            model=model,
            estimators=estimators,
            builder_env_stat_calc=builder_env_stat_calc,
            available_stat_calculators=available_stat_calculators,
            generation_metrics=generation_metrics,
            ue_metrics=ue_metrics,
            processors=[
                Logger(),
            ],
            ignore_exceptions=args.ignore_exceptions,
            max_new_tokens=args.max_new_tokens,
            save_stats=getattr(args, "save_stats", []),
            log_time=getattr(args, "log_time", False),
        )

        try:
            man()
        except Exception:
            man.state = "failed"
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            # Persist whatever was computed, even after a failure.
            man.save(save_path + f"/ue_manager_seed{seed}")

        if report_to_wandb:
            # Iterate key/value pairs; iterating the mapping itself would
            # yield only keys and the metric values would never be logged.
            # NOTE(review): assumes man.gen_metrics is a dict like man.metrics - confirm.
            wandb.log({str(k): v for k, v in man.gen_metrics.items()})
            wandb.log({str(k): v for k, v in man.metrics.items()})
            wandb.save(save_path + f"/ue_manager_seed{seed}")

    if report_to_wandb:
        wandb.finish()
        


def get_ue_metrics(args):
    """Build the list of UE metrics for the run.

    The list always contains ``PredictionRejectionArea()``,
    ``PredictionRejectionArea(max_rejection=0.5)``, ``IsotonicPCC()`` and
    ``ECE(normalize=True)``. ``ROCAUC()`` and ``PRAUC()`` are appended when
    ``args.use_claim_ue`` is truthy.

    Args:
        args: Run configuration; only ``use_claim_ue`` is read (default ``False``).

    Returns:
        List of UE metric instances.
    """
    metrics = [
        PredictionRejectionArea(),
        PredictionRejectionArea(max_rejection=0.5),
        IsotonicPCC(),
        ECE(normalize=True),
    ]
    if not getattr(args, "use_claim_ue", False):
        return metrics

    metrics.extend([ROCAUC(), PRAUC()])
    return metrics


def get_stat_calculator_names(config):
    """Collect the stat calculator containers available for this run.

    If ``"auto"`` appears in ``config.stat_calculators`` the default
    calculators are registered first. Every other entry is turned into a
    ``StatCalculatorContainer`` from its ``name``, ``cfg``, ``stats``,
    ``dependencies`` and ``builder`` fields.

    Args:
        config: Run configuration. Reads ``model.type``, ``language``,
            ``output_attentions``, ``hf_cache``, ``deberta_batch_size``,
            ``model.supports_logprobs`` and ``stat_calculators``.

    Returns:
        List of stat calculator containers (defaults first, then the
        explicitly configured ones in config order).
    """
    raw_type = getattr(config.model, "type", "Whitebox")

    # Anything that is not Blackbox or VisualLM is treated as Whitebox.
    if raw_type == "Blackbox":
        model_type = "Blackbox"
    elif raw_type == "VisualLM":
        model_type = "VisualLM"
    else:
        model_type = "Whitebox"

    is_vllm = raw_type == "vLLMCausalLM"
    # Attentions and hidden states are switched off for vLLM models.
    output_attentions = getattr(config, "output_attentions", True) and not is_vllm
    output_hidden_states = not is_vllm

    blackbox_supports_logprobs = model_type == "Blackbox" and getattr(
        config.model, "supports_logprobs", False
    )

    calculators = []
    if "auto" in config.stat_calculators:
        calculators.extend(
            register_default_stat_calculators(
                model_type,
                getattr(config, "language", "en"),
                getattr(config, "hf_cache", None),
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                blackbox_supports_logprobs=blackbox_supports_logprobs,
                deberta_batch_size=getattr(config, "deberta_batch_size", 10),
            )
        )

    # "auto" is only a marker for the defaults above, not a calculator spec.
    calculators.extend(
        StatCalculatorContainer(
            name=spec.name,
            cfg=spec.cfg,
            stats=spec.stats,
            dependencies=spec.dependencies,
            builder=spec.builder,
        )
        for spec in config.stat_calculators
        if spec != "auto"
    )
    return calculators


def get_ue_methods(config, model):
    """Instantiate every estimator listed in ``config.estimators``.

    Args:
        config: Run configuration; each entry of ``config.estimators`` has a
            ``name`` and optionally a ``cfg``.
        model: Unused here; kept for the existing call signature.

    Returns:
        List of estimators created through ``FactoryEstimator``, in config order.
    """
    build_estimator = FactoryEstimator()
    return [
        # Entries without a "cfg" key are built with an empty configuration.
        build_estimator(spec.name, spec.cfg if "cfg" in spec else dict())
        for spec in config.estimators
    ]


def get_generation_metrics(args):
    """Build the generation metrics for the run.

    When ``args.generation_metrics`` is set and non-empty, each entry names a
    metric class (looked up in this module's globals) plus optional ``args``
    and ``kwargs``. Otherwise a default set is built, with a few task- and
    model-dependent additions.

    If ``process_output_fn`` or ``process_target_fn`` is configured, every
    metric is wrapped in ``PreprocessOutputTarget``. If ``multiref`` is truthy,
    every metric is then wrapped in ``AggregatedMetric``.

    Args:
        args: Run configuration.

    Returns:
        List of generation metric instances.

    Raises:
        ValueError: If processing functions are configured together with
            ``target_ignore_regex``, ``output_ignore_regex`` or ``normalize``.
    """
    log.info("=" * 100)
    log.info("Initializing generation metrics...")

    configured_metrics = getattr(args, "generation_metrics", None)
    ignore_regex = getattr(args, "source_ignore_regex", None)

    if configured_metrics:
        result = []
        for spec in configured_metrics:
            # Metric classes are resolved by name from the module namespace.
            metric_cls = globals()[spec["name"]]
            positional = spec.get("args", [])
            keyword = spec.get("kwargs", {})
            result.append(metric_cls(*positional, **keyword))
    else:
        result = [
            RougeMetric("rouge1"),
            RougeMetric("rouge2"),
            RougeMetric("rougeL"),
            BLEUMetric(),
            BertScoreMetric(),
            SbertMetric(),
            AccuracyMetric(
                target_ignore_regex=getattr(args, "target_ignore_regex", None),
                output_ignore_regex=getattr(args, "output_ignore_regex", None),
                normalize=getattr(args, "normalize", False),
            ),
        ]

        if args.task == "ats":
            result.append(
                AlignScore(
                    target_is_claims=False,
                    source_ignore_regex=ignore_regex,
                    source_as_target=True,
                )
            )
        else:
            result.append(AlignScore(target_is_claims=True))

        not_blackbox = getattr(args.model, "type", "Whitebox") != "Blackbox"
        if not_blackbox and getattr(args, "use_claim_ue", False):
            result.append(
                OpenAIFactCheck(
                    cache_path=args.cache_path,
                    language=getattr(args, "language", "en"),
                    n_threads=getattr(args, "n_threads", 1),
                )
            )

        if args.task == "nmt":
            result.append(Comet(source_ignore_regex=ignore_regex))

    output_fn_cfg = getattr(args, "process_output_fn", None)
    target_fn_cfg = getattr(args, "process_target_fn", None)
    if target_fn_cfg or output_fn_cfg:
        uses_builtin_cleanup = (
            getattr(args, "target_ignore_regex", None)
            or getattr(args, "output_ignore_regex", None)
            or getattr(args, "normalize", False)
        )
        if uses_builtin_cleanup:
            raise ValueError(
                "Specifying ignore_regex or normalize simultaneously with process scripts is not allowed."
            )

        def _resolve_fn(fn_cfg):
            # Load the function named by fn_cfg from its script; None if unset.
            if not fn_cfg:
                return None
            script = get_abs_path_from_hydra_config(fn_cfg.path)
            return getattr(load_external_module(script), fn_cfg.fn_name)

        output_fn = _resolve_fn(output_fn_cfg)
        target_fn = _resolve_fn(target_fn_cfg)

        result = [
            PreprocessOutputTarget(inner, output_fn, target_fn) for inner in result
        ]

    if getattr(args, "multiref", False):
        # Each metric gets its own AggregatedMetric wrapper.
        result = [AggregatedMetric(base_metric=inner) for inner in result]

    log.info("Done with initializing generation metrics.")

    return result


def get_model(args):
    """Dispatch to the model loader matching ``args.model.type``.

    ``"Blackbox"``, ``"VisualLM"`` and ``"vLLMCausalLM"`` select their
    dedicated loaders; any other value (or a missing ``type``) loads a
    whitebox model.

    Args:
        args: Run configuration.

    Returns:
        The model object produced by the selected loader.
    """
    model_type = getattr(args.model, "type", "Whitebox")

    if model_type == "Blackbox":
        return get_blackbox_model(args)
    if model_type == "vLLMCausalLM":
        return get_vllm_model(args)

    # Visual and whitebox loaders both receive the same cache settings.
    hf_kwargs = {
        "cache_dir": getattr(args, "hf_cache", None),
        "token": getattr(args, "hf_token", None),
    }
    loader = get_visual_model if model_type == "VisualLM" else get_whitebox_model
    return loader(args, hf_kwargs)


def get_blackbox_model(args):
    """Create a ``BlackboxModel`` for the provider named in ``args.model.provider``.

    Supported providers are ``"openai"`` (API key read from ``OPENAI_API_KEY``)
    and ``"huggingface"`` (API key read from ``HUGGINGFACE_API_KEY``).

    Args:
        args: Run configuration; reads ``model.provider``, ``model.path`` and
            ``model.supports_logprobs`` (default ``False``).

    Returns:
        The ``BlackboxModel`` built for the provider.

    Raises:
        ValueError: If the provider is missing, blank or unsupported, or if
            the provider's API key is not in the environment.
    """
    provider = getattr(args.model, "provider", "")
    if provider is None or not provider.strip():
        raise ValueError(
            "Blackbox model provider cannot be empty or None. Please specify a valid provider."
        )

    supports_logprobs = getattr(args.model, "supports_logprobs", False)

    if provider == "openai":
        api_key = os.environ.get("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError("OpenAI API key is not set in the environment variables.")
        return BlackboxModel.from_openai(
            openai_api_key=api_key,
            model_path=args.model.path,
            supports_logprobs=supports_logprobs,
        )

    if provider == "huggingface":
        api_key = os.environ.get("HUGGINGFACE_API_KEY")
        if api_key is None:
            raise ValueError(
                "HuggingFace API key is not set in the environment variables."
            )
        return BlackboxModel.from_huggingface(
            hf_api_token=api_key, hf_model_id=args.model.path
        )

    raise ValueError(f"Unsupported black-box model provider: {provider}")


def get_whitebox_model(args, cache_kwargs=None):
    """Build a ``WhiteboxModel`` from a model path or from a loading script.

    Without ``args.model.path_to_load_script`` the model is created through
    ``WhiteboxModel.from_pretrained`` (a path the code itself logs as
    deprecated). Otherwise the script is loaded as a module and its
    ``load_model`` / ``load_tokenizer`` functions are called with
    ``model_path`` plus the configured ``load_model_args`` /
    ``load_tokenizer_args``.

    Args:
        args: Run configuration.
        cache_kwargs: Extra keyword arguments forwarded to ``from_pretrained``
            on the direct-path route. ``None`` (the default) means no extras.

    Returns:
        The constructed ``WhiteboxModel``.
    """
    # Use a fresh dict per call instead of a shared mutable default.
    cache_kwargs = {} if cache_kwargs is None else cache_kwargs

    if "path_to_load_script" not in args.model or not args.model.path_to_load_script:
        log.warning(
            "Loading model by directly passing the path to the model is deprecated and will be removed in the next release. Please use loading script instead."
        )
        log.info(f"Loading model with cache_kwargs: {cache_kwargs}")
        return WhiteboxModel.from_pretrained(
            args.model.path,
            getattr(args, "generation_params", {}),
            device_map=args.model.load_model_args.device_map,
            add_bos_token=getattr(args.model, "add_bos_token", True),
            instruct=getattr(args, "instruct", False),
            **cache_kwargs,
        )

    # Relative script paths are resolved against the Hydra config location.
    path_to_load_script = get_abs_path_from_hydra_config(args.model.path_to_load_script)
    load_module = load_external_module(path_to_load_script)

    load_model_args = {"model_path": args.model.path}
    load_model_args.update(args.model.load_model_args)
    base_model = load_module.load_model(**load_model_args)

    load_tok_args = {"model_path": args.model.path}
    load_tok_args.update(args.model.load_tokenizer_args)
    tokenizer = load_module.load_tokenizer(**load_tok_args)

    # Combine the YAML generation parameters with the model's own generation config.
    generation_params = GenerationParametersFactory.from_params(
        yaml_config=getattr(args, "generation_params", {}),
        native_config=base_model.generation_config.to_dict(),
    )

    return WhiteboxModel(
        base_model,
        tokenizer,
        args.model.path,
        args.model.type,
        generation_params,
        instruct=getattr(args, "instruct", False),
    )


def get_visual_model(args, cache_kwargs=None):
    """Build a ``VisualWhiteboxModel`` from a model path or from a loading script.

    Without ``args.model.path_to_load_script`` the model is created through
    ``VisualWhiteboxModel.from_pretrained`` (a path the code itself logs as
    deprecated). Otherwise the script is loaded as a module, its
    ``load_model`` and ``load_tokenizer`` functions are called, and the
    processor is obtained through ``load_processor``.

    Args:
        args: Run configuration.
        cache_kwargs: Extra keyword arguments forwarded to ``from_pretrained``
            on the direct-path route. ``None`` (the default) means no extras.

    Returns:
        The constructed ``VisualWhiteboxModel``.
    """
    # Use a fresh dict per call instead of a shared mutable default.
    cache_kwargs = {} if cache_kwargs is None else cache_kwargs

    if "path_to_load_script" not in args.model or not args.model.path_to_load_script:
        log.warning(
            "Loading model by directly passing the path to the model is deprecated and will be removed in the next release. Please use loading script instead."
        )
        log.info(f"Loading model with cache_kwargs: {cache_kwargs}")
        return VisualWhiteboxModel.from_pretrained(
            args.model.path,
            getattr(args, "generation_params", {}),
            device_map=args.model.load_model_args.device_map,
            add_bos_token=getattr(args.model, "add_bos_token", True),
            **cache_kwargs,
        )

    # Relative script paths are resolved against the Hydra config location.
    path_to_load_script = get_abs_path_from_hydra_config(args.model.path_to_load_script)
    load_module = load_external_module(path_to_load_script)

    load_model_args = {"model_path": args.model.path}
    load_model_args.update(args.model.load_model_args)
    base_model = load_module.load_model(**load_model_args)

    # NOTE(review): the tokenizer is loaded but not passed to the model below.
    # The call is kept because removing it could change side effects - confirm
    # whether it can be dropped.
    load_tok_args = {"model_path": args.model.path}
    load_tok_args.update(args.model.load_tokenizer_args)
    tokenizer = load_module.load_tokenizer(**load_tok_args)

    load_proc_args = {"model_path": args.model.path}
    load_proc_args.update(getattr(args.model, "load_processor_args", {}))
    processor = load_processor(**load_proc_args)

    # Combine the YAML generation parameters with the model's own generation config.
    generation_params = GenerationParametersFactory.from_params(
        yaml_config=getattr(args, "generation_params", {}),
        native_config=base_model.generation_config.to_dict(),
    )

    return VisualWhiteboxModel(
        base_model,
        processor,
        args.model.path,
        args.model.type,
        generation_params,
    )


def get_vllm_model(args):
    """Build a ``WhiteboxModelvLLM`` through the configured loading script.

    The script's ``load_model`` is called with ``model_path``,
    ``max_new_tokens`` and ``logprobs`` plus everything in
    ``args.model.load_model_args``; it is expected to return a
    ``(model, sampling_params)`` pair.

    Args:
        args: Run configuration.

    Returns:
        The constructed ``WhiteboxModelvLLM``.
    """
    script_path = get_abs_path_from_hydra_config(args.model.path_to_load_script)
    loader = load_external_module(script_path)

    # Entries from load_model_args override the three defaults set here.
    loader_kwargs = {
        "model_path": args.model.path,
        "max_new_tokens": args.max_new_tokens,
        "logprobs": args.model.logprobs,
    }
    loader_kwargs.update(args.model.load_model_args)
    vllm_model, sampling_params = loader.load_model(**loader_kwargs)

    generation_parameters = GenerationParameters(
        **getattr(args, "generation_params", {})
    )

    return WhiteboxModelvLLM(
        model=vllm_model,
        sampling_params=sampling_params,
        generation_parameters=generation_parameters,
        device=args.model.device,
        instruct=getattr(args.model, "instruct", False),
    )


def get_abs_path_from_hydra_config(path: str) -> Path:
    """Resolve ``path`` relative to the directory of the Hydra config file.

    Args:
        path: Absolute path, or a path relative to the Hydra config's directory.

    Returns:
        ``path`` as a ``Path`` when it is already absolute, otherwise
        ``hydra_config.parent / path``.
    """
    candidate = Path(path)
    if os.path.isabs(candidate):
        return candidate
    return hydra_config.parent / candidate


# Script entry point. main() is wrapped by @hydra.main, so it is called here
# without arguments.
if __name__ == "__main__":
    main()
