#!/usr/bin/env python3

from __future__ import print_function

import argparse
import grp
import os
import pwd
import re
import shutil
import sys

import hvac
import requests
import urllib3

import vault_certificate_deploy
from vault_certificate_deploy import base, cert_ops

# Silence urllib3's InsecureRequestWarning for the whole process. It would
# otherwise be emitted for HTTPS requests made with TLS verification turned
# off ("verify_tls = no" in the [vault] section sets verify=False below).
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Default configuration file (overridden with -c)
default_config_file = "/etc/vault-certificate-deploy/config.cnf"
# Vault PKI mount point used when neither --vault-pki nor [vault] mount_point
# is given
default_secret_mount_point = "cert"
# Certificate role name used when --cert-role is given as an empty value
default_role_name = "default"


#
# Create parser
#
def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the deploy script.

    Options are declared below as ``(flag strings, add_argument kwargs)``
    pairs and registered in table order, which is also the order in which
    they are listed by ``--help``.

    Returns:
        A fully configured ``argparse.ArgumentParser``.
    """

    cli = argparse.ArgumentParser(
        description="Certificate deploy script for HashiCorp Vault",
        epilog="Created by Robert Vojcik <robert@vojcik.net>",
    )

    option_table = (
        (
            ("-c",),
            dict(
                dest="config_file",
                default=default_config_file,
                help=f"configuration file (default: {default_config_file})",
            ),
        ),
        (
            ("--role-id",),
            dict(
                dest="role_id",
                default=False,
                help="Role id, as argument instead in config file",
            ),
        ),
        (
            ("--secret-id",),
            dict(
                dest="secret_id",
                default=False,
                help="Secret id, as argument instead in config file",
            ),
        ),
        (
            ("--vault-pki",),
            dict(
                dest="mount_point",
                default=False,
                help="Vault PKI secrets mount point",
            ),
        ),
        (
            ("--cert-ttl",),
            dict(
                dest="cert_ttl",
                default="7776000",
                help="Certificate TTL when issuing new certificate (seconds)",
            ),
        ),
        (
            ("--cert-role",),
            dict(
                dest="role_name",
                default="default",
                help="Certificate role in Vault",
            ),
        ),
        (
            ("--cert-min-ttl",),
            dict(
                dest="cert_min_ttl",
                default="1209600",
                help="Certificate minimum TTL when checking certificate (seconds)",
            ),
        ),
        (
            ("--cert-extra-options",),
            dict(
                dest="cert_extra_options",
                default=False,
                help="Certificate extra options as KEY1=VALUE delimited by semicolon. More info on: https://www.vaultproject.io/api/secret/pki/index.html#generate-certificate",
            ),
        ),
        (
            ("-n",),
            dict(
                dest="cert_name",
                default=False,
                help="Certificate name to deploy",
            ),
        ),
        (
            ("--cert-list",),
            dict(
                dest="cert_list",
                default=False,
                help="File containing list of certificates to deploy",
            ),
        ),
        (
            ("--ignore-ssl-check",),
            dict(
                dest="ignore_ssl_check",
                default=False,
                action="store_true",
                help="Skip certificate check",
            ),
        ),
        (
            ("--version",),
            dict(
                dest="version_print",
                default=False,
                action="store_true",
                help="Show version of the script",
            ),
        ),
        (
            ("-d",),
            dict(
                dest="debug",
                default=False,
                action="store_true",
                help="debug mode",
            ),
        ),
    )

    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    return cli


def parse_cert_list_line(line: str) -> dict:
    """Parse a cert list line into a dict with name and optional per-cert options.

    Expected format (whitespace separated, only NAME is mandatory)::

        NAME [PKI_MOUNT [PKI_ROLE [KEY=VALUE;KEY=VALUE;...]]]

    ``PKI_MOUNT`` / ``PKI_ROLE`` default to the module globals
    ``vault_mount_point`` / ``role_name`` (assigned in ``prepare_variables``).

    In the ``KEY=VALUE`` list, a key naming one of the entry fields (other
    than ``extra``) overrides that field. ``cert_perms`` is parsed as octal,
    and ``cert_ttl`` / ``cert_renew_min_ttl`` are parsed as integers. Every
    other key, including ``extra`` itself, is stored in ``entry["extra"]``,
    which the caller forwards to Vault as extra issue parameters. Empty
    segments (e.g. a trailing ``;``) are ignored.

    A ``cert_owner`` / ``cert_group`` that does not exist on this system is
    reported with a warning and reset to ``None``. The caller then falls
    back to the configured deploy user/group.

    Args:
        line: One cert-list line, or the value given with ``-n``.

    Returns:
        Entry dict with keys ``name``, ``pki_mount_name``, ``pki_policy``,
        ``cert_owner``, ``cert_group``, ``cert_perms``, ``cert_copypath``,
        ``cert_ttl``, ``cert_renew_min_ttl`` and ``extra``.

    Invalid input (a blank line, a segment without ``=``, or a non-numeric
    permission/TTL value) is reported through ``base.eexit``.
    """

    if not line.strip():
        base.eexit(1, "Invalid cert line: non-empty line with only white symbols")
    parts = line.split()
    entry = {
        "name": parts[0],
        "pki_mount_name": vault_mount_point,
        "pki_policy": role_name,
        "cert_owner": None,
        "cert_group": None,
        "cert_perms": None,
        "cert_copypath": None,
        "cert_ttl": None,
        "cert_renew_min_ttl": None,
        "extra": {},
    }
    if len(parts) > 1:
        entry["pki_mount_name"] = parts[1]
    if len(parts) > 2:
        entry["pki_policy"] = parts[2]
    if len(parts) <= 3:
        return entry
    for part in parts[3].split(";"):
        # Ignore empty segments produced by a trailing ";" or by ";;"
        if not part:
            continue
        if "=" in part:
            # Split on the first "=" only, so values may contain "="
            k, v = part.split("=", 1)
            if k not in entry or k == "extra":
                # Unknown keys are Vault issue parameters. "extra" is routed
                # here too: overwriting the dict with a string would break
                # the later params.update() in the caller.
                entry["extra"][k] = v
            elif k == "cert_perms":
                try:
                    entry[k] = int(v, 8)
                except ValueError:
                    base.eexit(1, f"Invalid octal permissions {v}")
            elif k in ("cert_ttl", "cert_renew_min_ttl"):
                try:
                    entry[k] = int(v)
                except ValueError:
                    base.eexit(
                        1, f"Invalid cert_ttl or cert_renew_min_ttl value {v}"
                    )
            else:
                entry[k] = v
        else:
            base.eexit(1, "Cert arguments errors")
    # Drop owners/groups unknown to this system so the caller falls back to
    # the deploy user/group from the config file
    if entry["cert_owner"] is not None:
        try:
            _ = pwd.getpwnam(entry["cert_owner"]).pw_uid
        except KeyError:
            base.pwrn(
                f"Unable to find user {entry['cert_owner']} in the system, fallback to default from a config file"
            )
            entry["cert_owner"] = None
    if entry["cert_group"] is not None:
        try:
            _ = grp.getgrnam(entry["cert_group"]).gr_gid
        except KeyError:
            base.pwrn(
                f"Unable to find group {entry['cert_group']} in the system, fallback to default from a config file"
            )
            entry["cert_group"] = None
    return entry


#
# Prepare variables
#
def prepare_variables() -> None:
    """Prepare variables from config file and arguments.

    Reads the module-level ``args`` (parsed command line) and ``config``
    (the ``base.ConfigParse`` wrapper, accessed through ``.parser``). The
    results are published as the module globals declared below. Where both
    sources provide a value, the command-line value takes precedence over
    the config file. Unrecoverable problems are reported through
    ``base.eexit``:
    - missing role/secret id;
    - no ``-n`` / ``--cert-list``;
    - invalid cert-list lines.
    """

    global storage_core_path
    global storage_paths
    global role_id
    global secret_id
    global verify
    global vault_mount_point
    global cert_list
    global deploy_user
    global deploy_group
    global deploy_user_id
    global deploy_group_id
    global role_name

    cert_list = []

    # Role ID
    if args.role_id:
        base.pdeb("Role id set from argument", args.debug)
        role_id = args.role_id
    else:
        try:
            role_id = config.parser.get("approle", "role_id")
        except Exception:
            base.perr("Unable to determine role-id")
            base.eexit(
                1, "You have to provide role-id in configuration file or as argument"
            )
        else:
            base.pdeb("Role id set from config file", args.debug)

    # Secret ID
    if args.secret_id:
        base.pdeb("Secret id set from argument", args.debug)
        secret_id = args.secret_id
    else:
        try:
            secret_id = config.parser.get("approle", "secret_id")
        except Exception:
            base.perr("Unable to determine secret-id")
            base.eexit(
                1, "You have to provide secret-id in configuration file or as argument"
            )
        else:
            base.pdeb("Secret id set from config file", args.debug)

    # Role Name
    if args.role_name:
        base.pdeb("Role name set from argument", args.debug)
        role_name = args.role_name
    else:
        role_name = default_role_name

    if config.parser.get("vault", "verify_tls") == "no":
        base.pdeb("TLS Verify disabled", args.debug)
        verify = False
    else:
        base.pdeb("TLS Verify enabled", args.debug)
        verify = True

    # Deploy User and Group (default to root when not configured)
    try:
        deploy_user = config.parser.get("vault", "deploy_user")
    except Exception:
        base.pdeb("Deploy User not find in config file, settings root", args.debug)
        deploy_user = "root"
    else:
        base.pdeb("Deploy User set from config file", args.debug)

    try:
        deploy_group = config.parser.get("vault", "deploy_group")
    except Exception:
        base.pdeb("Deploy Group not find in config file, settings root", args.debug)
        deploy_group = "root"
    else:
        base.pdeb("Deploy Group set from config file", args.debug)

    # Resolve numeric ids; unknown names fall back to uid/gid 0
    try:
        deploy_user_id = pwd.getpwnam(deploy_user).pw_uid
    except KeyError:
        base.pwrn(f"Unable to find user {deploy_user} in the system, fallback to root")
        deploy_user_id = 0
    try:
        deploy_group_id = grp.getgrnam(deploy_group).gr_gid
    except KeyError:
        base.pwrn(
            f"Unable to find group {deploy_group} in the system, fallback to root"
        )
        deploy_group_id = 0

    # Vault mount point: argument > config file > built-in default.
    # Must be set before cert-list parsing, which uses it as a default.
    if args.mount_point:
        vault_mount_point = args.mount_point
    elif "mount_point" in config.parser.options("vault"):
        vault_mount_point = config.parser.get("vault", "mount_point")
    else:
        vault_mount_point = default_secret_mount_point
    base.pdeb("Vault secret mount point: " + vault_mount_point, args.debug)

    # Cert list and Cert Name
    # Prepare cert_list, merge config with command line if needed
    # We create simple datastructure here in dict per CA

    # Use Cert file
    if args.cert_list:
        with open(args.cert_list, "r") as fh:
            raw_lines = fh.read().split("\n")
        # Skip blank / whitespace-only lines and comments, including indented
        # ones; parse_cert_list_line() splits on whitespace, so passing the
        # stripped line is equivalent for everything else.
        stripped_lines = (raw.strip() for raw in raw_lines)
        cert_list = [
            parse_cert_list_line(text)
            for text in stripped_lines
            if text and not text.startswith("#")
        ]

    # User certname argument
    if args.cert_name:
        cert_list.append(parse_cert_list_line(args.cert_name))

    if not args.cert_name and not args.cert_list:
        base.eexit(
            1, "Unable to determine certificate to deploy. Use -n or --cert-list"
        )
    base.pdeb(f"Merged cert_list: {str(cert_list)}", args.debug)

    # Storage Path: one sub-directory per PKI mount in use
    storage_core_path = config.parser.get("storage", "path")
    storage_paths = {f"{storage_core_path}/{x['pki_mount_name']}" for x in cert_list}


#
# Clean certificates
#
def clean_certificates(storage: str, certificates: list[tuple[str, str]]) -> None:
    """Remove deployed certificate directories that are no longer wanted.

    Only directories laid out as ``<storage>/<ca_name>/<certs|private>/<cert_name>``
    are candidates; this is the layout the deploy step writes into. A
    candidate that contains files, and whose ``(cert_name, ca_name)`` pair is
    not in *certificates*, is removed with ``shutil.rmtree``.

    Everything else under *storage* is left untouched:
    - the storage root itself;
    - files directly inside a CA, ``certs`` or ``private`` directory;
    - deeper sub-directories;
    - the top-level ``custom`` and ``global`` trees (global is handled by a
      different script).

    Args:
        storage: Root storage directory (``[storage] path`` from the config).
        certificates: ``(cert_name, ca_name)`` pairs that must be kept.
    """

    # Exclude custom & global path by default
    # Global path is done by different script
    excluded_top_dirs = {"custom", "global"}
    # Sub-directories of a CA directory that hold per-certificate dirs
    managed_kinds = {"certs", "private"}
    wanted = set(certificates)
    storage = os.path.normpath(storage)
    # Array of directories to be deleted
    dirs_to_delete = []

    # Find all directories in storage that contain files
    for root, _, filenames in os.walk(storage):
        if not filenames:
            continue
        rel_parts = os.path.relpath(root, storage).split(os.sep)
        # Only exact <ca>/<kind>/<cert> directories are managed here. Anything
        # shallower (e.g. the storage root or a certs/ dir holding a stray
        # file) or deeper must never be removed wholesale.
        if len(rel_parts) != 3:
            continue
        ca_name, kind, cert_name = rel_parts
        if ca_name in excluded_top_dirs or kind not in managed_kinds:
            continue
        if (cert_name, ca_name) not in wanted:
            dirs_to_delete.append(root)

    if dirs_to_delete:
        base.pout("There are some old cert dirs to be deleted")
        for delete_dir in dirs_to_delete:
            base.pout("Removing directory " + delete_dir)
            shutil.rmtree(delete_dir)


#
# MAIN PROGRAM LOOP
#
if __name__ == "__main__":

    errors_count = 0

    # Parsing arguments
    parser = create_parser()
    args = parser.parse_args()

    if args.version_print:
        print(vault_certificate_deploy.__version__)
        sys.exit(0)

    # Config file parsing
    base.pdeb(f"Loading configuration from {args.config_file}", args.debug)
    config = base.ConfigParse(args.config_file)
    cert_ops.validate_configuration(config)

    # Prepare variables from configuration and arguments. This assigns the
    # module globals used below (cert_list, storage paths, ids, verify, ...).
    prepare_variables()

    # Empty certificates are error state
    if len(cert_list) < 1:
        base.eexit(1, "There are no certificates to deploy.")

    # Global --cert-extra-options (KEY=VALUE;KEY=VALUE), parsed once. Used only
    # for certificates whose cert-list line carries no extra options of its own.
    global_extra_options = {}
    if args.cert_extra_options:
        for opt in args.cert_extra_options.split(";"):
            # Tolerate empty segments such as a trailing ";"
            if not opt:
                continue
            if "=" in opt:
                # Split on the first "=" only so values may contain "="
                key, value = opt.split("=", 1)
                global_extra_options[key] = value
            else:
                base.eexit(1, f"Invalid --cert-extra-options entry: {opt}")

    # Post-hooks directory: configurable via [hooks] post_hooks_dir,
    # falls back to the historical default if the option isn't set.
    try:
        post_hooks_dir = config.parser.get("hooks", "post_hooks_dir")
        base.pdeb(f"Post-hooks directory from config: {post_hooks_dir}", args.debug)
    except Exception:
        post_hooks_dir = "/etc/vault-certificate-deploy/post-issue-hooks.d"
        base.pdeb(f"Post-hooks directory using default: {post_hooks_dir}", args.debug)
    # Becomes True once any certificate file is (re)written
    post_hooks_enabled = False

    # Vault Auth, with approle
    vault = hvac.Client(url=config.parser.get("vault", "address"), verify=verify)
    try:
        auth_token = vault.auth.approle.login(role_id=role_id, secret_id=secret_id)

    except requests.ConnectTimeout as e:
        base.perr(f"Connection Timeout: {str(e)}")
        base.eexit(1, "Connection Timeout")

    except requests.ConnectionError as e:
        base.perr(f"Connection Error: {str(e)}")
        base.eexit(1, "Connection Error")

    except hvac.exceptions.InvalidRequest as e:
        base.eexit(1, f"VAULT: {str(e)}")

    # Debug output
    # NOTE(review): this prints the whole login response with -d, presumably
    # including the client token; consider redacting it.
    base.pdeb("Vault auth connection: " + str(auth_token), args.debug)

    # For every storage path
    for storage_path in storage_paths:
        path_cert = storage_path + "/certs"
        path_private = storage_path + "/private"
        # Prepare Basic Directories
        for path, perms in [(path_cert, 0o0755), (path_private, 0o0750)]:
            cert_ops.create_directory(
                path, perms, deploy_user_id, deploy_group_id, debug=args.debug
            )

    # Go through certificates and issue or renew
    certificates = []
    for cert_entry in cert_list:
        ca_name = cert_entry["pki_mount_name"]
        cert_name = cert_entry["name"]
        cert_options = cert_entry["extra"]
        cert_file = f"{storage_core_path}/{ca_name}/certs/{cert_name}/{cert_name}.crt"
        cert_priv_file = (
            f"{storage_core_path}/{ca_name}/private/{cert_name}/{cert_name}.key"
        )
        cert_role_name = cert_entry["pki_policy"]

        # Use per-cert cert_renew_min_ttl if set, otherwise use global --cert-min-ttl
        cert_min_ttl = (
            cert_entry["cert_renew_min_ttl"]
            if cert_entry["cert_renew_min_ttl"] is not None
            else int(args.cert_min_ttl)
        )

        if cert_ops.certificate_check(
            cert_name, cert_file, cert_priv_file, cert_min_ttl, args.debug
        ):
            # Certificate is OK and Valid
            base.pdeb(f"Certificate {cert_name} is ok and valid", args.debug)
            continue

        # Certificate does not exist or it's not OK: issue a new one.
        # Per-cert extra options take precedence over --cert-extra-options.
        extra_options = cert_options if cert_options else global_extra_options

        # Use per-cert cert_ttl if set, otherwise use global --cert-ttl
        cert_ttl = args.cert_ttl
        if cert_entry["cert_ttl"] is not None:
            cert_ttl = cert_entry["cert_ttl"]
        # Extra options may override the ttl, as with the original update()
        params = {"ttl": cert_ttl}
        params.update(extra_options)

        # Get certificate
        try:
            response = vault.secrets.pki.generate_certificate(
                mount_point=ca_name,
                name=cert_role_name,
                common_name=cert_name,
                extra_params=params,
            )
        except hvac.exceptions.InvalidRequest as e:
            base.perr(
                f"Unable to issue {cert_name} certificate against vault server. Get {e}"
            )
            errors_count += 1
            continue

        # NOTE(review): with -d this prints the issued private key; consider
        # redacting response["data"]["private_key"].
        base.pdeb("Response: " + str(response), args.debug)

        if (isinstance(response, dict)) and ("data" in response):
            secret = {
                "data": {
                    "key": response["data"]["private_key"],
                    "ica": response["data"]["issuing_ca"],
                    "crt": response["data"]["certificate"],
                }
            }
            certificates.append((cert_entry, secret))
        else:
            base.perr(
                f"Unable to issue {cert_name} certificate against vault server. Get {str(response)}"
            )
            errors_count += 1

    # Deploy certificates
    for certificate_t in certificates:
        cert_entry, certificate = certificate_t
        cert_name = cert_entry["name"]
        ca_name = cert_entry["pki_mount_name"]
        cert_owner = cert_entry["cert_owner"]
        cert_group = cert_entry["cert_group"]
        cert_perms = cert_entry["cert_perms"]
        cert_copypath = cert_entry["cert_copypath"]
        cert_renew_min_ttl = cert_entry["cert_renew_min_ttl"]
        cert_dir_cert = f"{storage_core_path}/{ca_name}/certs/{cert_name}"
        cert_dir_private = f"{storage_core_path}/{ca_name}/private/{cert_name}"

        # Resolve per-cert owner/group, falling back to global config
        c_user = cert_owner if cert_owner is not None else deploy_user
        c_group = cert_group if cert_group is not None else deploy_group
        try:
            c_user_id = pwd.getpwnam(c_user).pw_uid
        except KeyError:
            base.pwrn(f"Unable to find user {c_user} for {cert_name}, fallback to root")
            c_user_id = 0
        try:
            c_group_id = grp.getgrnam(c_group).gr_gid
        except KeyError:
            base.pwrn(
                f"Unable to find group {c_group} for {cert_name}, fallback to root"
            )
            c_group_id = 0

        # Resolve per-cert permissions, falling back to defaults.
        # cert_perms applies to the non-key files only.
        key_file_perm = 0o640
        cert_file_perm = cert_perms if cert_perms is not None else 0o644

        # Check Certificate. Avoid deploying wrong certificates
        if args.ignore_ssl_check is False:
            effective_min_ttl = (
                cert_renew_min_ttl
                if cert_renew_min_ttl is not None
                else int(args.cert_min_ttl)
            )
            test_result = cert_ops.certificate_validate(
                certificate_t, effective_min_ttl
            )

            if test_result is not True:
                base.pwrn(f"Certificate {cert_name} not pass the checks, skipping")
                errors_count += 1
                # Skip all and move to next cert
                continue

        # Create cert directories
        for path, perms in [(cert_dir_cert, 0o0755), (cert_dir_private, 0o0750)]:
            cert_ops.create_directory(
                path, perms, c_user_id, c_group_id, debug=args.debug
            )

        # Prepare the optional extra copy destination (cert_copypath option);
        # a failure here is counted but does not stop the deployment.
        copy_files_enabled = False
        if cert_copypath:
            if cert_ops.create_directory(
                cert_copypath,
                0o755,
                c_user_id,
                c_group_id,
                exit_on_error=False,
                keep_previous_ownership=True,
                debug=args.debug,
            ):
                copy_files_enabled = True
            else:
                errors_count += 1

        # Create certificate files: "key" goes to the private dir, the rest
        # ("crt", "ica") to the certs dir
        for key, content in certificate["data"].items():
            if key == "key":
                file_path = cert_dir_private + "/" + cert_name + "." + key
            else:
                file_path = cert_dir_cert + "/" + cert_name + "." + key

            # Only rewrite (and trigger post hooks) when the content changed
            write_cert = True
            if os.path.isfile(file_path):
                with open(file_path, "r") as fh:
                    prev_cert = fh.read()
                if prev_cert == content:
                    write_cert = False
                    base.pdeb(
                        "Not writing " + file_path + " - already exists and is the same",
                        args.debug,
                    )

            if write_cert:
                post_hooks_enabled = True
                base.pdeb("Writing " + file_path, args.debug)
                with open(file_path, "w") as fh:
                    fh.write(content)

            # Change file permissions
            if key == "key":
                os.chmod(file_path, key_file_perm)
            else:
                os.chmod(file_path, cert_file_perm)

            os.chown(file_path, c_user_id, c_group_id)

            # Copy cert file to additional path if cert_copypath is set
            if copy_files_enabled:
                dst_path = cert_copypath + "/" + cert_name + "." + key
                shutil.copy2(file_path, dst_path)
                os.chown(dst_path, c_user_id, c_group_id)

    # Clean unwanted certificates
    clean_certificates(
        storage=storage_core_path,
        certificates=[(c["name"], c["pki_mount_name"]) for c in cert_list],
    )

    # Eventually run post hooks
    if post_hooks_enabled:
        errors_count += cert_ops.run_post_hooks(post_hooks_dir, args.debug)

    if errors_count > 0:
        base.eexit(1, f"There was {errors_count} errors during process")
