import os
import numpy as np
import importlib
import triton_python_backend_utils as pb_utils
from uber.ai.michelangelo.lib.model_manager._private.serde.loader.custom_model_loader import load_custom_model


def load_model(model_path):
    """Load the custom model whose artifacts live in ``model_path``.

    The directory is expected to contain:

    * ``model_class.txt`` -- the fully-qualified ``package.module.ClassName``
      of the model class to resolve;
    * ``model`` -- the model binary path handed to ``load_custom_model``.

    Args:
        model_path: Directory holding ``model_class.txt`` and ``model``.

    Returns:
        Whatever ``load_custom_model`` returns for the resolved class.

    Raises:
        ValueError: If ``model_class.txt`` is empty or does not name a
            module-qualified class.
        ImportError: If the module part of the class path cannot be imported.
        AttributeError: If the module has no attribute with the class name.
    """
    model_binary_path = os.path.join(model_path, "model")
    class_file = os.path.join(model_path, "model_class.txt")
    with open(class_file, encoding="utf-8") as f:
        model_class = f.read().strip()

    if not model_class:
        raise ValueError("Model class is missing")

    module_name, _, class_name = model_class.rpartition(".")
    # rpartition yields an empty module name for a bare "ClassName" (and an
    # empty class name for a trailing dot); fail with an actionable message
    # instead of importlib's opaque "Empty module name".
    if not module_name or not class_name:
        raise ValueError(
            "Model class must be fully qualified (module.ClassName), "
            f"got {model_class!r}"
        )

    module = importlib.import_module(module_name)
    # Named model_cls so it does not shadow the module-level ``Model`` wrapper.
    model_cls = getattr(module, class_name)
    return load_custom_model(model_binary_path, model_cls, model_path)


class Model():
    """Triton Python-backend wrapper around a custom model.

    On construction it loads the model stored in the directory containing
    this file (via ``load_model``). ``execute`` splits a batched request into
    per-record ``predict`` calls and stacks the per-record outputs back into
    batched arrays.
    """

    def __init__(self):
        # Not used by the code in this class itself.
        self.logger = pb_utils.Logger
        # We have pre-downloaded the model binary to the below model path.
        self.model_path = os.path.dirname(os.path.realpath(__file__))
        self.model = load_model(self.model_path)

    def execute(self, inputs):
        """Run the model on a batch of inputs, one record at a time.

        Args:
            inputs: Mapping of input name to an array whose first axis is
                the batch dimension. The batch size is taken from the first
                input; all inputs are assumed to share it.

        Returns:
            If ``inputs`` is empty, the model's ``predict`` result for it,
            unchanged. Otherwise a dict mapping each output name to the
            per-record outputs stacked along a new leading batch axis, with
            fixed-width unicode arrays converted to ``object`` dtype.

        Raises:
            ValueError: If records in the batch produce different sets of
                output names (the outputs could not be batched consistently).
        """
        if not inputs:
            # Nothing to split into records: let the model handle it directly.
            return self.model.predict(inputs)

        # batched inputs
        batch_size = next(iter(inputs.values())).shape[0]

        # Collect each output's per-record values in a list and stack once at
        # the end; concatenating inside the loop re-copies the growing array
        # for every record, which is quadratic in the batch size.
        columns = {}
        for i in range(batch_size):
            record = {k: v[i] for k, v in inputs.items()}
            result = self.model.predict(record)
            if not columns:
                columns = {k: [v] for k, v in result.items()}
                continue
            if result.keys() != columns.keys():
                raise ValueError(
                    "Model outputs differ across records in the batch: "
                    f"expected {list(columns)}, got {list(result)} "
                    f"at index {i}"
                )
            for k, v in result.items():
                columns[k].append(v)

        # np.stack adds the leading batch axis, equivalent to expand_dims on
        # every value followed by a single concatenate along axis 0.
        batched_results = {
            k: np.stack(values, axis=0) for k, values in columns.items()
        }
        # convert all string arrays to object arrays
        # NOTE(review): presumably required by Triton for string/BYTES
        # outputs -- confirm against the model config.
        return {
            k: v.astype(np.object_) if v.dtype.type == np.str_ else v
            for k, v in batched_results.items()
        }


def get_model():
    """Construct the serving wrapper for this model directory.

    Returns:
        A freshly constructed ``Model``; construction loads the model from
        the directory containing this file.
    """
    wrapper = Model()
    return wrapper
