#!/usr/bin/env python3

import argparse
import json
import logging
import time
from datetime import datetime

from MitreidConnect.MitreidClientApi import mitreidClientApi
from ServiceRegistryAms.PullPublish import PullPublish
from Utils.common import create_ams_response, get_log_conf
from Utils.oauth import refresh_token_grant

# Setup logger
log = logging.getLogger(__name__)


# Translate a token endpoint authentication method name as received in the
# registry message (e.g. "client_secret_basic") into the value that is put in
# the mitreID message. Returns None when the name is not one of the known ones.
def map_token_endpoint_value(key):
    known_methods = (
        ("client_secret_post", "SECRET_POST"),
        ("client_secret_basic", "SECRET_BASIC"),
        ("client_secret_jwt", "SECRET_JWT"),
        ("private_key_jwt", "PRIVATE_KEY"),
        ("none", "NONE"),
    )
    for method_name, mitreid_value in known_methods:
        if key == method_name:
            return mitreid_value
    return None


# format_mitreid_msg gets a message from ams rciam-federation in snake_case
# and returns a new dict in camelCase format that is acceptable to the
# mitreID API. The incoming message itself is not modified.
#    Function format_mitreid_msg gets 2 arguments:
#    - msg, the registry message (dict with string keys)
#    - deployment_type, only used to decide what happens to "createdAt":
#      removed on "create", reformatted on "edit", left untouched otherwise
def format_mitreid_msg(msg, deployment_type):
    # snake_case -> camelCase: keep the first component, capitalise the rest
    mitreid_msg = {}
    for key, value in msg.items():
        head, *tail = key.split("_")
        mitreid_msg[head + "".join(part.title() for part in tail)] = value

    # The contacts are flattened to a list of e-mail addresses, keeping only
    # the technical and support ones. A message that carries no contacts, or
    # a contact entry without type/email, must not abort the whole conversion
    # with a KeyError before mitreID is even called.
    if "contacts" in msg:
        mitreid_msg["contacts"] = [
            contact["email"]
            for contact in msg["contacts"] or []
            if contact.get("type") in ("technical", "support") and "email" in contact
        ]

    # Fields that have a different name on the mitreID side
    if "serviceName" in mitreid_msg:
        mitreid_msg["clientName"] = mitreid_msg.pop("serviceName")
    if "serviceDescription" in mitreid_msg:
        mitreid_msg["clientDescription"] = mitreid_msg.pop("serviceDescription")
    if "idTokenTimeoutSeconds" in mitreid_msg:
        mitreid_msg["idTokenValiditySeconds"] = mitreid_msg.pop("idTokenTimeoutSeconds")

    # The external id is never part of the message body
    mitreid_msg.pop("externalId", None)

    if "createdAt" in mitreid_msg:
        if deployment_type == "create":
            mitreid_msg.pop("createdAt")
        elif deployment_type == "edit":
            try:
                # Only the first 19 characters (date and time up to seconds)
                # are parsed; anything after them is dropped.
                # NOTE(review): the "+0000" suffix presumes the incoming
                # timestamp is UTC - confirm against the producer.
                created = datetime.strptime(mitreid_msg["createdAt"][:19], "%Y-%m-%dT%H:%M:%S")
                mitreid_msg["createdAt"] = created.strftime("%Y-%m-%dT%H:%M:%S+0000")
            except ValueError as err:
                # The unparsable value is left in the message as it was
                log.critical(err)

    if "tokenEndpointAuthMethod" in mitreid_msg:
        mitreid_msg["tokenEndpointAuthMethod"] = map_token_endpoint_value(mitreid_msg["tokenEndpointAuthMethod"])
    if "aupUri" in mitreid_msg:
        mitreid_msg["tosUri"] = mitreid_msg.pop("aupUri")
    if "websiteUrl" in mitreid_msg:
        mitreid_msg["clientUri"] = mitreid_msg.pop("websiteUrl")
    return mitreid_msg


# This function sends every incoming registry message to mitreID, using a
# mitreidClient built from the given issuer and access token, and collects
# the responses that have to be published back to ams.
#    Function update_data gets 4 arguments:
#    - messages, the new incoming messages (list of dicts). Each message is
#      modified in place: its "id" is removed here and its "deployment_type"
#      is removed by call_mitreid
#    - issuer_url, the url of the issuer
#    - access_token, the token the mitreid client is created with
#    - deployer_name, passed through to create_ams_response
#    Returns a list with one {"attributes": {}, "data": <ams response>} entry
#    per incoming message.
def update_data(messages, issuer_url, access_token, deployer_name):
    pub_messages = []  # messages to be published
    mitreid_agent = mitreidClientApi(issuer_url, access_token)  # Create mitreid agent
    for msg in messages:
        log.debug("Message from ams: %s", msg)
        service_id = msg.pop("id")  # Remove rciam service id to make request to mitreId
        # Stay empty unless call_mitreid returns successfully
        external_id = ""
        client_id = ""
        try:
            response, external_id, client_id = call_mitreid(msg, mitreid_agent)
            log.info("Message received from mitreId: %s", response)
            ams_message = create_ams_response(response, service_id, deployer_name, external_id, client_id)
        except Exception:
            # A failing message is reported back to ams as an error instead of
            # stopping the batch. Only Exception is caught so that
            # KeyboardInterrupt/SystemExit can still stop the agent, and the
            # traceback is logged so the cause is not lost.
            log.critical("Exception catch, return error to ams", exc_info=True)
            ams_message = create_ams_response(
                {"status": 0, "error": "An error occurred while calling mitreId"},
                service_id,
                deployer_name,
                external_id,
                client_id,
            )
        pub_messages.append({"attributes": {}, "data": ams_message})
    return pub_messages


# Publish the collected messages to the ams upstream topic.
#    Function publish_ams gets 2 arguments:
#    - pub_messages, the messages to be published; nothing is sent when
#      there are none
#    - ams_agent, the ams agent whose publish method handles the operation
def publish_ams(pub_messages, ams_agent):
    # Nothing to send, so do not log or call the agent
    if len(pub_messages) == 0:
        return
    log.info("Publish messaged to ams")
    rendered = str(pub_messages)
    log.debug("Messages published to ams: " + rendered)
    ams_agent.publish(pub_messages)


# Calls mitreid with the operation named by the "deployment_type" of the
# registry message (the key is removed from the message in the process).
# Operations handled:
# - create
# - delete
# - edit
# Any other deployment type performs no request and yields an empty response.
# Returns the tuple (response, external_id, client_id), where external_id and
# client_id are empty strings when they are not known.
def call_mitreid(registry_message, mitreid_agent):
    operation = registry_message.pop("deployment_type")
    payload = format_mitreid_msg(registry_message, operation)
    log.debug("Formatted message for mitreId: " + str(payload))

    if operation == "create":
        log.info("Create new client")
        reply = mitreid_agent.createClient(payload)
        # The ids are only read from a reply with status 200
        if reply["status"] == 200:
            created = reply["response"]
            return reply, str(created["id"]), created["clientId"]
        return reply, "", ""

    if operation == "delete":
        target_id = registry_message["external_id"]
        target_id_text = str(target_id)
        log.info("Delete client with id: " + target_id_text)
        reply = mitreid_agent.deleteClientById(target_id)
        return reply, target_id_text, ""

    if operation == "edit":
        target_id = registry_message["external_id"]
        target_id_text = str(target_id)
        log.info("Update client with id: " + target_id_text)
        reply = mitreid_agent.updateClientById(target_id, payload)
        if reply["status"] == 200:
            return reply, target_id_text, reply["response"]["clientId"]
        return reply, target_id_text, ""

    return {}, "", ""


if __name__ == "__main__":
    # Get config path from arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", required=True, type=str, help="Configuration file location path")
    args = parser.parse_args()
    path = args.c
    # JSON text is UTF-8, so do not depend on the locale's default encoding
    with open(path, encoding="utf-8") as json_data_file:
        config = json.load(json_data_file)

    # Get log_conf from project arguments else use the global setting
    if "log_conf" in config["mitreid"]:
        get_log_conf(config["mitreid"]["log_conf"])
    else:
        get_log_conf(config["log_conf"])

    log.info("Init ams agent")
    ams = PullPublish(config["mitreid"]["ams"])

    # Poll ams forever. The loop only ends through an exception (e.g. Ctrl-C),
    # so the exit message is logged from a finally clause; the exception
    # itself still propagates unchanged.
    try:
        while True:
            log.info("Pull messages from ams")
            messages, ids = ams.pull(1)
            log.info("Received " + str(len(messages)) + " messages from ams")
            if len(messages) > 0:
                # A fresh access token is requested for every batch of messages
                log.info("Get access token from " + config["mitreid"]["issuer"])
                access_token = refresh_token_grant(
                    config["mitreid"]["issuer"],
                    config["mitreid"]["refresh_token"],
                    config["mitreid"]["client_id"],
                    config["mitreid"]["client_secret"],
                )
                # NOTE(review): the messages are acknowledged before they are
                # processed, so a crash in update_data/publish_ams would drop
                # them - confirm this is the intended delivery guarantee.
                ams.ack(ids)
                responses = update_data(messages, config["mitreid"]["issuer"], access_token, "")
                publish_ams(responses, ams)
            time.sleep(config["mitreid"]["ams"]["poll_interval"])
    finally:
        log.info("Exit script")
