#!python

import argparse
import json
import logging
import time

from Keycloak.KeycloakClientApi import KeycloakClientApi
from ServiceRegistryAms.PullPublish import PullPublish
from Utils.common import create_ams_response, get_keycloak_issuer, get_log_conf
from Utils.oauth import client_credentials_grant, refresh_token_grant

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Base Keycloak client representation for OIDC clients, kept as a JSON string.
# format_keycloak_msg parses a fresh copy with json.loads for every message and
# then overrides individual fields. Note that "standard.token.exchange.enabled"
# is a JSON boolean here while the other attribute flags are the strings
# "true"/"false"; deploy_to_keycloak compares that attribute against True/False.
jsonOidcTemplate = '{"attributes":{"client_credentials.use_refresh_token":"false","oauth2.device.authorization.grant.enabled":"false","standard.token.exchange.enabled":false,"oidc.ciba.grant.enabled":"false","refresh.token.max.reuse":"0","revoke.refresh.token":"false","use.jwks.string":"false","use.jwks.url":"false","use.refresh.tokens":"false"},"directAccessGrantsEnabled":false,"implicitFlowEnabled":false,"publicClient":false,"serviceAccountsEnabled":false,"standardFlowEnabled":false,"webOrigins":["+"]}'
# Base Keycloak client representation for SAML clients (JSON string), parsed by
# format_keycloak_msg in the same way as the OIDC template.
jsonSamlTemplate = (
    '{"attributes":{"saml.auto.updated":"true","saml.refresh.period":"3600","saml.skip.requested.attributes": "true"}}'
)

# Protocol mapper definition (JSON string) that update_service_account parses
# and passes to agent.add_mapper; it maps the "username" user attribute to the
# "preferred_username" claim of the access token.
clientCredentialsMapper = '{"name":"preferred_username","protocol":"openid-connect","protocolMapper":"oidc-usermodel-property-mapper","consentRequired":false,"config":{"userinfo.token.claim":"false","user.attribute":"username","id.token.claim":"false","access.token.claim":"true","claim.name":"preferred_username","introspection.response.claim":"false","jsonType.label":"String"}}'


def map_token_endpoint_value(key):
    """Translate a token endpoint auth method into a client authenticator type.

    Returns "client-secret" for "client_secret_post", "client_secret_basic"
    and "none", "client-secret-jwt" for "client_secret_jwt" and "client-jwt"
    for "private_key_jwt". Any other value yields None.
    """
    if key in ("client_secret_post", "client_secret_basic", "none"):
        return "client-secret"
    if key == "client_secret_jwt":
        return "client-secret-jwt"
    if key == "private_key_jwt":
        return "client-jwt"
    return None


def _format_oidc_client(msg, realm_default_client_scopes, keycloak_config):
    """Build the Keycloak representation of an OIDC client from a registry message."""
    new_msg = json.loads(jsonOidcTemplate)
    attributes = new_msg["attributes"]
    new_msg["defaultClientScopes"] = realm_default_client_scopes
    new_msg["protocol"] = "openid-connect"
    # Consent is not required for OIDC clients unless the configuration says so
    new_msg["consentRequired"] = keycloak_config.get("oidc_consent", False)
    if "client_id" in msg and msg["client_id"]:
        new_msg["clientId"] = msg.pop("client_id")
    if "redirect_uris" in msg:
        new_msg["redirectUris"] = msg.pop("redirect_uris")
    if "scope" in msg:
        optional_scopes = msg.pop("scope")
        # "openid" is stripped from the optional client scopes
        if "openid" in optional_scopes:
            optional_scopes.remove("openid")
        new_msg["optionalClientScopes"] = optional_scopes
    for grant_type in msg.get("grant_types", []):
        if grant_type == "authorization_code":
            new_msg["standardFlowEnabled"] = True
        if grant_type == "client_credentials":
            new_msg["serviceAccountsEnabled"] = True
        if grant_type == "urn:ietf:params:oauth:grant-type:token-exchange":
            # Kept as a boolean: deploy_to_keycloak compares this attribute
            # against True/False
            attributes["standard.token.exchange.enabled"] = True
        if grant_type == "urn:ietf:params:oauth:grant-type:device_code":
            # String flag, matching the "false" default of the template
            attributes["oauth2.device.authorization.grant.enabled"] = "true"
        if grant_type == "implicit":
            new_msg["implicitFlowEnabled"] = True
    if "token_endpoint_auth_method" in msg:
        auth_method = msg.pop("token_endpoint_auth_method")
        if auth_method == "none":
            new_msg["publicClient"] = True
        new_msg["clientAuthenticatorType"] = map_token_endpoint_value(auth_method)
    if "client_secret" in msg:
        new_msg["secret"] = msg.pop("client_secret")
    if "token_endpoint_auth_signing_alg" in msg:
        attributes["token.endpoint.auth.signing.alg"] = msg.pop("token_endpoint_auth_signing_alg")
    if "jwks" in msg:
        attributes["use.jwks.string"] = "true"
        attributes["jwks.string"] = json.dumps(msg.pop("jwks"))
    if "jwks_uri" in msg:
        # String flag, consistent with "use.jwks.string" above and the template
        attributes["use.jwks.url"] = "true"
        attributes["jwks.url"] = msg.pop("jwks_uri")
    if "refresh_token_validity_seconds" in msg:
        attributes["client.offline.session.max.lifespan"] = str(msg["refresh_token_validity_seconds"])
        # NOTE(review): "reuse_refresh_token" is only honoured when a refresh
        # token validity is also present - confirm this nesting is intended.
        if "reuse_refresh_token" in msg:
            rotate_refresh_token = str(not msg.pop("reuse_refresh_token"))
            attributes["revoke.refresh.token"] = rotate_refresh_token.lower()
    if "code_challenge_method" in msg:
        attributes["pkce.code.challenge.method"] = msg.pop("code_challenge_method")
    if "access_token_validity_seconds" in msg:
        # Lifespans shorter than 60 seconds are raised to 60
        if msg["access_token_validity_seconds"] < 60:
            attributes["access.token.lifespan"] = "60"
        else:
            attributes["access.token.lifespan"] = str(msg.pop("access_token_validity_seconds"))
    if "id_token_timeout_seconds" in msg:
        attributes["id.token.lifespan"] = str(msg.pop("id_token_timeout_seconds"))
    if "device_code_validity_seconds" in msg:
        attributes["oauth2.device.code.lifespan"] = str(msg.pop("device_code_validity_seconds"))
    return new_msg


def _format_saml_client(msg, realm_default_client_scopes, keycloak_config):
    """Build the Keycloak representation of a SAML client from a registry message."""
    new_msg = json.loads(jsonSamlTemplate)
    new_msg["protocol"] = msg.pop("protocol")
    # Consent is required for SAML clients unless the configuration says otherwise
    new_msg["consentRequired"] = keycloak_config.get("saml_consent", True)
    if "metadata_url" in msg:
        new_msg["attributes"]["saml.metadata.url"] = msg.pop("metadata_url")
    if "entity_id" in msg:
        new_msg["clientId"] = msg.pop("entity_id")
    if "requested_attributes" in msg:
        client_default_client_scopes, new_msg["protocolMappers"] = add_saml_scopes_and_mappers(
            msg["requested_attributes"]
        )
        # Union of the attribute scopes and the realm defaults, sorted and deduplicated
        new_msg["defaultClientScopes"] = sorted(set(client_default_client_scopes + realm_default_client_scopes))
    return new_msg


def format_keycloak_msg(msg, realm_default_client_scopes, keycloak_config):
    """Convert a snake_case registry message into a Keycloak client representation.

    The message comes from ams rciam-federation in snake_case and is turned into
    the camelCase structure accepted by the Keycloak API. Most consumed keys are
    popped from ``msg``, so the caller's dict is modified.

    Args:
        msg: registry message; ``msg["protocol"]`` must be "oidc" or "saml".
        realm_default_client_scopes: list of default client scope names of the realm.
        keycloak_config: deployer configuration; the optional keys
            "oidc_consent" / "saml_consent" set ``consentRequired``.

    Returns:
        dict with the client representation.

    Raises:
        ValueError: if the protocol is missing or is neither "oidc" nor "saml".
    """
    protocol = msg.get("protocol")
    if protocol == "oidc":
        new_msg = _format_oidc_client(msg, realm_default_client_scopes, keycloak_config)
    elif protocol == "saml":
        new_msg = _format_saml_client(msg, realm_default_client_scopes, keycloak_config)
    else:
        # Fail with a clear error instead of hitting an unbound local further down
        raise ValueError("Unsupported or missing protocol in message: " + repr(protocol))

    # Options for any type of client
    attributes = new_msg["attributes"]
    if "service_name" in msg:
        new_msg["name"] = msg.pop("service_name")
    if "logo_uri" in msg:
        attributes["logoUri"] = msg.pop("logo_uri")
    if "website_url" in msg:
        new_msg["baseUrl"] = msg.pop("website_url")
    if "service_description" in msg:
        new_msg["description"] = msg.pop("service_description")
    if "country" in msg:
        attributes["country"] = msg.pop("country").upper()
    if "policy_uri" in msg:
        attributes["policyUri"] = msg.pop("policy_uri")
    if "aup_uri" in msg:
        attributes["tosUri"] = msg.pop("aup_uri")
    if "contacts" in msg:
        emails = [
            contact["email"] for contact in msg["contacts"] if contact["type"] in ("technical", "support")
        ]
        # Remove duplicate entries while keeping the first-seen order, so the
        # resulting attribute is deterministic
        attributes["contacts"] = ",".join(dict.fromkeys(emails))

    return new_msg


def process_data(messages, access_token, keycloak_config):
    """Deploy every incoming message to Keycloak and build the replies for ams.

    A KeycloakClientApi agent is created from the given access token and each
    message is passed to deploy_to_keycloak. A failure while handling one
    message is reported back to ams as an error response and does not stop the
    processing of the remaining messages.

    Args:
        messages: the new incoming messages (list of dicts). The "id" key is
            popped from every message before it is sent to Keycloak.
        access_token: access token used by the Keycloak agent.
        keycloak_config: configuration for Keycloak; "auth_server" and "realm"
            are read here and the whole dict is forwarded to deploy_to_keycloak.

    Returns:
        list of dicts of the form {"attributes": {}, "data": <ams response>},
        one per incoming message.
    """
    auth_server = keycloak_config["auth_server"]
    realm = keycloak_config["realm"]
    deployer_name = ""
    # messages to be published
    pub_messages = []
    # Create Keycloak agent
    keycloak_agent = KeycloakClientApi(auth_server, realm, access_token)
    for msg in messages:
        log.debug("Message from ams: " + str(msg))
        # Remove rciam service id to make request to Keycloak
        service_id = msg.pop("id")
        external_id = ""
        client_id = ""
        try:
            response, external_id, client_id = deploy_to_keycloak(msg, keycloak_agent, keycloak_config)
            log.info("Message received from Keycloak: " + str(response))
            ams_message = create_ams_response(response, service_id, deployer_name, external_id, client_id)
        except Exception:
            # Only regular errors are turned into an error reply; KeyboardInterrupt
            # and SystemExit propagate. exc_info keeps the traceback in the log.
            log.critical("Exception catch, return error to ams", exc_info=True)
            ams_message = create_ams_response(
                {"status": 0, "error": "An error occurred while calling Keycloak"},
                service_id,
                deployer_name,
                external_id,
                client_id,
            )
        pub_messages.append({"attributes": {}, "data": ams_message})
    return pub_messages


def publish_ams(pub_messages, ams_agent):
    """Publish the given messages to the ams upstream topic.

    Args:
        pub_messages: the messages to be published; nothing happens when empty.
        ams_agent: the ams agent that handles the operation via its
            ``publish`` method.
    """
    if len(pub_messages) == 0:
        return
    log.info("Publish messaged to ams")
    log.debug("Messages published to ams: " + str(pub_messages))
    ams_agent.publish(pub_messages)


def create_client_scopes(agent, client_uuid, client_config):
    """Create the custom optional client scopes of a client and attach them to it.

    Requested optional scopes that do not exist in the realm are created,
    except parametric scopes (containing "?value="), which are ignored and
    removed from ``client_config["optionalClientScopes"]``. Every scope created
    here is then added to the client.

    Args:
        agent: Keycloak agent providing sync_realm_client_scopes,
            create_realm_oidc_client_scopes and add_client_scope_by_id.
        client_uuid: Keycloak id of the client.
        client_config: client representation; "optionalClientScopes" is read
            and may be modified in place. Nothing is done when the key is absent.
    """
    new_optional_client_scopes = client_config.get("optionalClientScopes")
    if new_optional_client_scopes is None:
        # No optional scopes were requested for this client
        return

    realm_client_scopes = agent.sync_realm_client_scopes()
    known_scopes = set(realm_client_scopes.keys())

    # Custom scopes that are not created in Keycloak (deduplicated, in request order)
    missing_scopes = [scope for scope in dict.fromkeys(new_optional_client_scopes) if scope not in known_scopes]
    parametric_scope_delimiter = "?value="
    created_scopes = []
    for scope in missing_scopes:
        if parametric_scope_delimiter in scope:
            # Ignore parametric scopes
            log.info("The scope '" + scope + "' will be ignored.")
            new_optional_client_scopes.remove(scope)
        else:
            # Create custom scope
            agent.create_realm_oidc_client_scopes(scope)
            created_scopes.append(scope)

    # Get updated client scopes
    realm_client_scopes = agent.sync_realm_client_scopes()

    # Only scopes that were really created are attached; ignored parametric
    # scopes do not exist in the realm and would fail the lookup below.
    for add_scope in created_scopes:
        agent.add_client_scope_by_id(client_uuid, realm_client_scopes[add_scope])


def update_client_scopes(agent, client_uuid, new_client_config, current_client_config):
    """Bring the client scopes of an existing client in line with the new config.

    SAML clients are compared on "defaultClientScopes", every other protocol on
    "optionalClientScopes". For "openid-connect" clients, requested scopes that
    are unknown to the realm are created first, except parametric scopes
    (containing "?value="), which are dropped from the requested list. Scopes
    no longer requested are removed from the client, newly requested ones added.
    """
    protocol = current_client_config["protocol"]
    scope_key = "defaultClientScopes" if protocol == "saml" else "optionalClientScopes"
    realm_scopes = agent.sync_realm_client_scopes()
    existing_scopes = current_client_config[scope_key]
    requested_scopes = new_client_config[scope_key]

    if protocol == "openid-connect":
        # Requested scopes the realm does not know yet
        unknown_scopes = list(set(requested_scopes) - set(realm_scopes.keys()))
        for scope_name in unknown_scopes:
            if "?value=" not in scope_name:
                # Plain custom scope: create it in the realm
                agent.create_realm_oidc_client_scopes(scope_name)
                continue
            # Parametric scope: skip it
            log.info("The scope '" + scope_name + "' will be ignored.")
            requested_scopes.remove(scope_name)

        # Refresh the realm scopes so the newly created ones can be resolved
        realm_scopes = agent.sync_realm_client_scopes()

    scopes_to_remove = list(set(existing_scopes) - set(requested_scopes))
    scopes_to_add = list(set(requested_scopes) - set(existing_scopes))

    for scope_name in scopes_to_remove:
        agent.remove_client_scope_by_id(client_uuid, realm_scopes[scope_name], protocol)

    for scope_name in scopes_to_add:
        agent.add_client_scope_by_id(client_uuid, realm_scopes[scope_name], protocol)


def add_saml_scopes_and_mappers(requested_attributes):
    """Derive client scope names and protocol mappers from requested SAML attributes.

    Every attribute contributes its "friendly_name" as a scope name. Attributes
    whose "type" is "custom" additionally get a "saml-user-attribute-mapper"
    protocol mapper built from their "friendly_name" and "name".

    Returns:
        tuple (scopes, mappers) of two lists, in the order of the input.
    """
    scope_names = []
    custom_mappers = []
    for requested in requested_attributes:
        friendly_name = requested["friendly_name"]
        scope_names.append(friendly_name)
        if requested["type"] != "custom":
            continue
        mapper_config = {
            "attribute.nameformat": "URI Reference",
            "user.attribute": friendly_name,
            "friendly.name": friendly_name,
            "attribute.name": requested["name"],
        }
        custom_mappers.append(
            {
                "name": friendly_name,
                "protocol": "saml",
                "protocolMapper": "saml-user-attribute-mapper",
                "config": mapper_config,
            }
        )
    return scope_names, custom_mappers


def update_service_account(agent, client_uuid, current_client_config, keycloak_config):
    """Set up the service account of a client.

    Fetches the service account user of the client, adds the protocol mapper
    defined in clientCredentialsMapper to the client and then updates the
    service account user through the agent.
    """
    account_lookup = agent.get_service_account_user(client_uuid)
    username_mapper = json.loads(clientCredentialsMapper)
    agent.add_mapper(client_uuid, username_mapper)
    agent.update_user(account_lookup["response"], current_client_config, keycloak_config)


def deploy_to_keycloak(registry_message, keycloak_agent, keycloak_config):
    """Call Keycloak according to the deployment type of the registry message.

    Operations handled:
    - create
    - delete
    - edit

    Any other deployment type performs no client operation and returns an
    empty response.

    Args:
        registry_message: message from the registry; "deployment_type" is
            popped and "protocol" ("oidc" or "saml") is read. The message is
            further modified by format_keycloak_msg.
        keycloak_agent: agent used for every Keycloak call.
        keycloak_config: deployer configuration; "service_account" is passed
            on when a service account has to be updated.

    Returns:
        tuple (response, external_id, client_id): the response of the
        create/delete/update call, the Keycloak id of the client and its
        client id ("" when not known).
    """
    deployment_type = registry_message.pop("deployment_type")
    protocol = registry_message["protocol"]
    if protocol == "oidc":
        protocol = "openid-connect"
    realm_default_client_scopes = keycloak_agent.get_realm_default_client_scopes(protocol)
    default_client_scopes = [scope["name"] for scope in realm_default_client_scopes]
    keycloak_msg = format_keycloak_msg(registry_message, default_client_scopes, keycloak_config)
    log.debug("Formatted message for Keycloak: " + str(keycloak_msg))
    response = {}
    external_id = ""
    client_id = ""
    if deployment_type == "create":
        log.info("Create new client")
        response = keycloak_agent.create_client(keycloak_msg)
        # The follow-up steps need the id of the new client, so they only run
        # when the creation succeeded; otherwise the response is returned as is.
        if response["status"] == 201:
            created_client = response["response"]
            client_id = created_client["clientId"]
            if "id" in created_client:
                external_id = created_client["id"]
            else:
                response_external_id = keycloak_agent.get_client_by_id(client_id)
                external_id = response_external_id["response"]["id"]
            if protocol == "openid-connect":
                create_client_scopes(keycloak_agent, external_id, keycloak_msg)
                if created_client["serviceAccountsEnabled"]:
                    update_service_account(
                        keycloak_agent, external_id, created_client, keycloak_config["service_account"]
                    )
    elif deployment_type == "delete":
        client_id = keycloak_msg["clientId"]
        external_id = registry_message.get("external_id", "")
        log.info("Delete client with id: " + str(client_id))
        response = keycloak_agent.delete_client(client_id)
    elif deployment_type == "edit":
        client_id = keycloak_msg["clientId"]
        log.info("Update client with id: " + str(client_id))
        response = keycloak_agent.update_client(client_id, keycloak_msg)
        external_id = response["response"]["id"]
        if protocol == "openid-connect":
            update_client_scopes(keycloak_agent, external_id, keycloak_msg, response["response"])
            if response["response"]["serviceAccountsEnabled"]:
                update_service_account(
                    keycloak_agent, external_id, response["response"], keycloak_config["service_account"]
                )
            client_authz_permissions_response = keycloak_agent.get_client_authz_permissions(external_id)
            authz_enabled = client_authz_permissions_response["response"]["enabled"]
            token_exchange_enabled = keycloak_msg["attributes"]["standard.token.exchange.enabled"]
            # Keep the client authz permissions in sync with the token exchange
            # flag. Explicit comparisons are kept on purpose: the value types
            # returned by the agent are not visible here.
            if authz_enabled == False and token_exchange_enabled == True:  # noqa: E712
                keycloak_agent.update_client_authz_permissions(external_id, "enable")
            elif authz_enabled == True and token_exchange_enabled == False:  # noqa: E712
                keycloak_agent.update_client_authz_permissions(external_id, "disable")
        elif protocol == "saml":
            update_client_scopes(keycloak_agent, external_id, keycloak_msg, response["response"])
    return response, external_id, client_id


if __name__ == "__main__":
    # The only command line option is -c, the path of the JSON configuration file
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-c", required=True, type=str, help="Configuration file location path")
    cli_args = arg_parser.parse_args()
    with open(cli_args.c) as config_file:
        config = json.load(config_file)
    keycloak_settings = config["keycloak"]

    # A "log_conf" entry under "keycloak" takes precedence over the global one
    if "log_conf" in keycloak_settings:
        log_conf = keycloak_settings["log_conf"]
    else:
        log_conf = config["log_conf"]
    get_log_conf(log_conf)

    log.info("Init ams agent")
    ams = PullPublish(keycloak_settings["ams"])

    # Poll ams forever: pull, deploy to Keycloak, publish the replies, sleep
    while True:
        log.info("Pull messages from ams")
        messages, ids = ams.pull(1)
        message_count = len(messages)
        log.info("Received " + str(message_count) + " messages from ams")
        if message_count > 0:
            log.info("Get access token from " + keycloak_settings["auth_server"])
            issuer = get_keycloak_issuer(keycloak_settings)
            # Use the refresh token grant when a refresh token is configured,
            # otherwise fall back to the client credentials grant
            if "refresh_token" in keycloak_settings:
                access_token = refresh_token_grant(
                    issuer,
                    keycloak_settings["refresh_token"],
                    keycloak_settings["client_id"],
                    keycloak_settings["client_secret"],
                )
            else:
                access_token = client_credentials_grant(
                    issuer,
                    keycloak_settings["client_id"],
                    keycloak_settings["client_secret"],
                )
            # Messages are acknowledged before they are processed
            ams.ack(ids)
            responses = process_data(messages, access_token, keycloak_settings)
            publish_ams(responses, ams)
        time.sleep(keycloak_settings["ams"]["poll_interval"])
    # Not reached: the loop above has no break
    log.info("Exit script")
