import os
import numpy as np
import importlib
import triton_python_backend_utils as pb_utils
from uber.ai.michelangelo.lib.model_manager._private.serde.loader.custom_model_loader import load_custom_model


def load_model(model_path):
    """Load the custom model stored in the directory ``model_path``.

    The directory is expected to contain:

    * ``model_class.txt`` -- the fully-qualified name of the model class,
      e.g. ``package.module.ClassName`` (surrounding whitespace is ignored);
    * ``model`` -- the model binary, whose path is handed to
      ``load_custom_model``.

    Note that the module named in ``model_class.txt`` is imported, which
    executes its module-level code, so that file must come from a trusted
    source.

    Args:
        model_path: Directory holding ``model_class.txt`` and ``model``.

    Returns:
        Whatever ``load_custom_model`` returns for the resolved class.

    Raises:
        ValueError: If ``model_class.txt`` is empty or whitespace-only, or its
            content is not of the form ``module.ClassName``.
        OSError: If ``model_class.txt`` cannot be opened or read.
        ImportError: If the module part of the class name cannot be imported.
        AttributeError: If that module has no attribute named like the class.
    """
    model_binary_path = os.path.join(model_path, "model")
    with open(os.path.join(model_path, "model_class.txt"), encoding="utf-8") as f:
        model_class = f.read().strip()

    if not model_class:
        raise ValueError("Model class is missing")

    module_name, _, class_name = model_class.rpartition(".")
    # Without a dot rpartition yields an empty module name, which importlib
    # rejects with a bare "Empty module name"; a trailing dot yields an empty
    # class name. Report both cases with the offending value instead.
    if not module_name or not class_name:
        raise ValueError(
            "Model class must be a fully-qualified name like "
            f"'package.module.ClassName', got {model_class!r}"
        )

    module = importlib.import_module(module_name)
    # Named model_cls (not Model) so it does not shadow the module-level
    # Model class defined in this file.
    model_cls = getattr(module, class_name)
    return load_custom_model(model_binary_path, model_cls, model_path)


class Model:
    """Serving wrapper around the custom model found next to this file.

    Attributes:
        logger: The ``pb_utils.Logger`` object (not used elsewhere in this
            class).
        model_path: Directory containing this source file; the model binary
            and ``model_class.txt`` are expected to be there.
        model: The object returned by ``load_model(model_path)``.
    """

    def __init__(self):
        self.logger = pb_utils.Logger
        # The model binary was pre-downloaded alongside this file, so the
        # directory of this (symlink-resolved) source file is the model path.
        this_file = os.path.realpath(__file__)
        self.model_path = os.path.dirname(this_file)
        self.model = load_model(self.model_path)

    def execute(self, inputs):
        """Run ``self.model.predict(inputs)`` and normalise string outputs.

        Assumes ``predict`` returns a mapping whose values expose ``.dtype``
        (presumably NumPy arrays -- confirm against the model contract).
        Arrays whose element type is exactly ``np.str_`` (NumPy unicode
        strings) are cast to ``np.object_``; every other value is passed
        through unchanged. A new dict is returned, preserving key order.
        """
        predictions = self.model.predict(inputs)
        normalised = {}
        for name, array in predictions.items():
            # Fixed-width unicode arrays become object arrays; presumably the
            # serving layer requires object dtype for string outputs -- confirm.
            if array.dtype.type == np.str_:
                array = array.astype(np.object_)
            normalised[name] = array
        return normalised


def get_model():
    """Construct and return a new ``Model``.

    Constructing the ``Model`` loads the model from disk, so each call
    performs a fresh load.
    """
    instance = Model()
    return instance
