#!/usr/bin/env python3

from __future__ import print_function

import argparse
import grp
import os
import pwd
import re
import shutil
import sys

import hvac
import requests
import urllib3

import vault_certificate_deploy
from vault_certificate_deploy import base, cert_ops

# Suppress urllib3's InsecureRequestWarning for the whole process. That warning
# is raised for HTTPS requests made without certificate verification, which is
# what the hvac client below does when the config sets [vault] verify_tls = no.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Default configuration file (overridable with -c)
default_config_file = "/etc/vault-certificate-deploy/config.cnf"
# Default Vault secrets mount point, used when neither --vault-mount nor
# [vault] mount_point is set
default_secret_mount_point = "cert"


#
# Create parser
#
def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the deploy script.

    Options are registered in a fixed order, which is also the order shown by
    ``--help``. Every option defaults to ``False`` except ``-c``, which
    defaults to ``default_config_file``. The pure flags (``--ignore-ssl-check``,
    ``--version``, ``-d``) use ``store_true``.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """

    # (option strings, add_argument keyword arguments), in registration order.
    option_table = [
        (
            ("-c",),
            dict(
                dest="config_file",
                default=default_config_file,
                help=f"configuration file (default: {default_config_file})",
            ),
        ),
        (
            ("--role-id",),
            dict(
                dest="role_id",
                default=False,
                help="Role id, as argument instead in config file",
            ),
        ),
        (
            ("--secret-id",),
            dict(
                dest="secret_id",
                default=False,
                help="Secret id, as argument instead in config file",
            ),
        ),
        (
            ("--vault-mount",),
            dict(
                dest="mount_point",
                default=False,
                help="Vault secrets mount point",
            ),
        ),
        (
            ("-n",),
            dict(
                dest="cert_name",
                default=False,
                help="Certificate name to deploy",
            ),
        ),
        (
            ("--cert-list",),
            dict(
                dest="cert_list",
                default=False,
                help="File containing list of certificates to deploy",
            ),
        ),
        (
            ("--ignore-ssl-check",),
            dict(
                dest="ignore_ssl_check",
                default=False,
                action="store_true",
                help="Skip certificate check",
            ),
        ),
        (
            ("--version",),
            dict(
                dest="version_print",
                default=False,
                action="store_true",
                help="Show version of the script",
            ),
        ),
        (
            ("-d",),
            dict(
                dest="debug",
                default=False,
                action="store_true",
                help="debug mode",
            ),
        ),
    ]

    cli = argparse.ArgumentParser(
        description="Certificate deploy script for HashiCorp Vault",
        epilog="Created by Robert Vojcik <robert@vojcik.net>",
    )
    for option_strings, settings in option_table:
        cli.add_argument(*option_strings, **settings)
    return cli


#
# Parse a single cert list line into a dict
#
def parse_cert_list_line(line: str) -> dict[str, str | int | None]:
    """Parse a cert list line into a dict with name and optional per-cert options."""

    if not line.strip():
        base.eexit(1, "Invalid cert line: non-empty line with only white symbols")
    parts = line.split()
    entry = {
        "name": parts[0],
        "cert_owner": None,
        "cert_group": None,
        "cert_perms": None,
        "cert_copypath": None,
    }
    if len(parts) == 1:
        return entry
    for part in parts[1].split(";"):
        if "=" in part:
            k, v = part.split("=", 1)
            if k in entry:
                if k == "cert_perms":
                    try:
                        entry[k] = int(v, 8)
                    except ValueError:
                        base.eexit(1, f"Invalid octal permissions {v}")
                else:
                    entry[k] = v
        else:
            base.eexit(1, "Cert arguments errors")
    if entry["cert_owner"] is not None:
        try:
            _ = pwd.getpwnam(entry["cert_owner"]).pw_uid
        except KeyError:
            base.pwrn(
                f"Unable to find user {entry['cert_owner']} in the system, fallback to default from a config file"
            )
            entry["cert_owner"] = None
    if entry["cert_group"] is not None:
        try:
            _ = grp.getgrnam(entry["cert_group"]).gr_gid
        except KeyError:
            base.pwrn(
                f"Unable to find group {entry['cert_group']} in the system, fallback to default from a config file"
            )
            entry["cert_group"] = None
    return entry


#
# Prepare variables
#
def _read_cert_list_file(path: str) -> list[dict[str, str | int | None]]:
    """Read the --cert-list file and parse every meaningful line.

    Blank lines and comment lines (whose first non-whitespace character is
    ``#``) are skipped. Every other line goes through ``parse_cert_list_line``.
    Exits via ``base.eexit`` if the file cannot be opened or read.
    """
    lines: list[str] = []
    try:
        with open(path, "r") as fh:
            lines = fh.read().splitlines()
    except OSError as e:
        base.eexit(1, f"Unable to read cert list file {path}: {str(e)}")
    # lstrip() so that indented "# ..." lines are treated as comments too
    return [
        parse_cert_list_line(line)
        for line in lines
        if line.strip() and not line.lstrip().startswith("#")
    ]


#
# Prepare variables
#
def prepare_variables() -> None:
    """Populate the module-level settings from the config file and CLI args.

    Reads the module globals ``config`` (the ``base.ConfigParse`` instance) and
    ``args`` (the parsed command-line arguments). It sets these globals:
    ``storage_path``, ``role_id``, ``secret_id``, ``verify``, ``path_cert``,
    ``path_private``, ``vault_mount_point``, ``cert_list``, ``deploy_user``,
    ``deploy_group``, ``deploy_user_id`` and ``deploy_group_id``.

    Command-line values take precedence over the config file. A missing
    deploy user/group falls back to root.

    Exits via ``base.eexit`` when:

    * the role id or secret id cannot be determined;
    * the cert list file cannot be read;
    * neither ``-n`` nor ``--cert-list`` was given.
    """

    global storage_path
    global role_id
    global secret_id
    global verify
    global path_cert
    global path_private
    global vault_mount_point
    global cert_list
    global deploy_user
    global deploy_group
    global deploy_user_id
    global deploy_group_id

    cert_list = []
    storage_path = config.parser.get("storage", "path")

    # Role ID: the CLI argument wins over the [approle] section
    if args.role_id:
        base.pdeb("Role id set from argument", args.debug)
        role_id = args.role_id
    else:
        try:
            role_id = config.parser.get("approle", "role_id")
        except Exception as _:
            base.perr("Unable to determine role-id")
            base.eexit(
                1, "You have to provide role-id in configuration file or as argument"
            )
        else:
            base.pdeb("Role id set from config file", args.debug)

    # Secret ID: the CLI argument wins over the [approle] section
    if args.secret_id:
        base.pdeb("Secret id set from argument", args.debug)
        secret_id = args.secret_id
    else:
        try:
            secret_id = config.parser.get("approle", "secret_id")
        except Exception as _:
            base.perr("Unable to determine secret-id")
            base.eexit(
                1, "You have to provide secret-id in configuration file or as argument"
            )
        else:
            base.pdeb("Secret id set from config file", args.debug)

    # TLS verification is only disabled by the exact value "no"
    if config.parser.get("vault", "verify_tls") == "no":
        base.pdeb("TLS Verify disabled", args.debug)
        verify = False
    else:
        base.pdeb("TLS Verify enabled", args.debug)
        verify = True

    # Deploy User and Group
    try:
        deploy_user = config.parser.get("vault", "deploy_user")
        base.pdeb("Deploy User set from config file", args.debug)
    except Exception as _:
        base.pdeb("Deploy User not find in config file, settings root", args.debug)
        deploy_user = "root"

    try:
        deploy_group = config.parser.get("vault", "deploy_group")
        base.pdeb("Deploy Group set from config file", args.debug)
    except Exception as _:
        base.pdeb("Deploy Group not find in config file, settings root", args.debug)
        deploy_group = "root"

    # Resolve names to numeric ids; unknown names fall back to uid/gid 0
    try:
        deploy_user_id = pwd.getpwnam(deploy_user).pw_uid
    except KeyError:
        base.pwrn(f"Unable to find user {deploy_user} in the system, fallback to root")
        deploy_user_id = 0
    try:
        deploy_group_id = grp.getgrnam(deploy_group).gr_gid
    except KeyError:
        base.pwrn(
            f"Unable to find group {deploy_group} in the system, fallback to root"
        )
        deploy_group_id = 0

    # Vault mount point: CLI argument, then config, then the built-in default
    if args.mount_point:
        vault_mount_point = args.mount_point
    elif "mount_point" in config.parser.options("vault"):
        vault_mount_point = config.parser.get("vault", "mount_point")
    else:
        vault_mount_point = default_secret_mount_point
    base.pdeb("Vault secret mount point: " + vault_mount_point, args.debug)

    # Prepare destination dir
    path_cert = storage_path + "/certs/"
    base.pdeb("Path certificates: " + path_cert, args.debug)
    path_private = storage_path + "/private/"
    base.pdeb("Path keys: " + path_private, args.debug)

    # Certificates from the --cert-list file
    if args.cert_list:
        cert_list = _read_cert_list_file(args.cert_list)

    # Certificate given with -n is appended to the list
    if args.cert_name:
        cert_list.append(parse_cert_list_line(args.cert_name))

    if not args.cert_name and not args.cert_list:
        base.eexit(
            1, "Unable to determine certificate to deploy. Use -n or --cert-list"
        )

    base.pdeb(f"Merged cert_list: {str(cert_list)}", args.debug)


#
# Clean certificates
#
def clean_certificates(storage: str, certificates: list[str]) -> None:
    """Remove stale certificate directories below ``storage``.

    Every directory under ``storage`` that directly contains files is removed
    with ``shutil.rmtree`` unless it is protected. Protected are:

    * certificate directories, i.e. paths ending with ``/<name>`` for a name
      in ``certificates``;
    * ``custom`` directories and everything below them;
    * ``storage`` itself;
    * any ancestor of a directory of the first two kinds, such as the
      directory that holds the certificate directories. A stray file next to
      a deployed certificate directory therefore never causes that
      certificate to be removed.

    A directory nested inside one that is already being removed is skipped,
    so it is not removed a second time.

    Args:
        storage: Root of the certificate storage tree.
        certificates: Names of the certificates that must be kept.
    """

    storage_root = os.path.normpath(storage)
    # Names are matched literally (no regex), so names containing '.', '*' or
    # '+' neither over-match other directories nor raise re.error.
    cert_suffixes = tuple("/" + crt for crt in certificates)

    def _is_custom(path: str) -> bool:
        # True for ".../custom" itself as well as ".../custom/<anything>"
        return "/custom/" in path + "/"

    def _is_ancestor(parent: str, child: str) -> bool:
        # Component-aware prefix test: "/a/b" is not an ancestor of "/a/bc"
        return child.startswith(parent.rstrip("/") + "/")

    # Walk the whole tree once, splitting it into protected directories and
    # deletion candidates (directories that directly contain files).
    protected_dirs = []
    candidates = []
    for root, _, filenames in os.walk(storage):
        if root.endswith(cert_suffixes) or _is_custom(root):
            protected_dirs.append(root)
        elif filenames:
            candidates.append(root)

    # os.walk is top-down, so a parent is always examined before its children.
    dirs_to_delete = []
    for candidate in candidates:
        if os.path.normpath(candidate) == storage_root:
            continue
        if any(_is_ancestor(candidate, kept) for kept in protected_dirs):
            continue
        if any(_is_ancestor(removed, candidate) for removed in dirs_to_delete):
            # Removed together with an ancestor; a second rmtree would fail.
            continue
        dirs_to_delete.append(candidate)

    if dirs_to_delete:
        base.pout("There are some old cert dirs to be deleted")
        for delete_dir in dirs_to_delete:
            base.pout("Removing directory " + delete_dir)
            shutil.rmtree(delete_dir)


#
# MAIN PROGRAM LOOP
#
if __name__ == "__main__":

    errors_count = 0

    # Parsing arguments
    parser = create_parser()
    args = parser.parse_args()

    if args.version_print:
        print(vault_certificate_deploy.__version__)
        sys.exit(0)

    # Config file parsing
    base.pdeb(f"Loading configuration from {args.config_file}", args.debug)
    config = base.ConfigParse(args.config_file)
    cert_ops.validate_configuration(config)

    # Prepare variables from configuration and arguments. This sets module
    # globals such as cert_list, path_cert, path_private, verify, role_id
    # and secret_id.
    prepare_variables()

    # Empty certificates are error state
    if len(cert_list) < 1:
        base.eexit(1, "There are no certificates to deploy.")

    # Post-hooks directory: configurable via [hooks] post_hooks_dir,
    # falls back to the historical default if the option isn't set.
    try:
        post_hooks_dir = config.parser.get("hooks", "post_hooks_dir")
        base.pdeb(f"Post-hooks directory from config: {post_hooks_dir}", args.debug)
    except Exception as _:
        post_hooks_dir = "/etc/vault-certificate-deploy/post-hooks.d"
        base.pdeb(f"Post-hooks directory using default: {post_hooks_dir}", args.debug)
    # Becomes True as soon as any certificate file is (re)written
    post_hooks_enabled = False

    # Vault Auth, with approle
    vault = hvac.Client(url=config.parser.get("vault", "address"), verify=verify)
    try:
        vault.auth.approle.login(role_id=role_id, secret_id=secret_id)

    except requests.ConnectTimeout as e:
        base.perr(f"Connection Timeout: {str(e)}")
        base.eexit(1, "Connection Timeout")

    except requests.ConnectionError as e:
        base.perr(f"Connection Error: {str(e)}")
        base.eexit(1, "Connection Error")

    except hvac.exceptions.VaultError as e:
        # Base class of hvac's errors. It covers InvalidRequest (bad role or
        # secret id) as well as e.g. Forbidden, instead of leaking a traceback.
        base.eexit(1, f"VAULT: {str(e)}")

    # Debug output. The login response contains the client token, so it is
    # deliberately not printed.
    base.pdeb("Vault auth connection: approle login succeeded", args.debug)

    # Read secrets into list of (cert_entry, secret) tuples
    certificates = []
    for cert_entry in cert_list:
        cert_name = cert_entry["name"]
        secret_path = vault_mount_point + "/" + cert_name
        base.pdeb("Retrieving secret " + cert_name, args.debug)

        # First try a plain read. hvac's Client.read returns None for a
        # missing path. Other Vault errors (e.g. permission denied on this
        # path form) also fall through to the KV v2 attempt.
        try:
            secret = vault.read(secret_path)
        except hvac.exceptions.VaultError as e:
            base.pdeb(f"Plain read of {secret_path} failed: {str(e)}", args.debug)
            secret = None

        if secret is None:
            # KV v2 fallback. read_secret raises (e.g. InvalidPath) rather than
            # returning None when the secret is missing. Treat that as "not
            # found" so a single bad entry is counted as an error instead of
            # aborting the whole run.
            try:
                secret = vault.secrets.kv.v2.read_secret(
                    mount_point=vault_mount_point, path=cert_name
                )
            except hvac.exceptions.VaultError as e:
                base.pdeb(f"KV v2 read of {secret_path} failed: {str(e)}", args.debug)
                secret = None
            if secret is not None:
                secret = secret["data"]

        if secret is not None:
            # Only the keys are logged: the values include private key material
            base.pdeb(
                f"Secret {cert_name} retrieved with keys {str(secret.keys())}",
                args.debug,
            )
            certificates.append((cert_entry, secret))
        else:
            base.perr(
                "Unable to retrieve secret " + vault_mount_point + "/" + cert_name
            )
            # Increment error count
            errors_count += 1

    # Prepare Basic Directories
    for path, perms in [(path_cert, 0o0755), (path_private, 0o0750)]:
        cert_ops.create_directory(path, perms, deploy_user_id, deploy_group_id)

    # Deploy certificates
    for certificate_t in certificates:
        cert_name = certificate_t[0]["name"]
        cert_owner = certificate_t[0]["cert_owner"]
        cert_group = certificate_t[0]["cert_group"]
        cert_perms = certificate_t[0]["cert_perms"]
        cert_copypath = certificate_t[0]["cert_copypath"]
        certificate = certificate_t[1]
        cert_dir_cert = path_cert + "/" + cert_name
        cert_dir_private = path_private + "/" + cert_name

        # Resolve per-cert owner/group, falling back to global config
        c_user = cert_owner if cert_owner is not None else deploy_user
        c_group = cert_group if cert_group is not None else deploy_group
        try:
            c_user_id = pwd.getpwnam(c_user).pw_uid
        except KeyError:
            base.pwrn(f"Unable to find user {c_user} for {cert_name}, fallback to root")
            c_user_id = 0
        try:
            c_group_id = grp.getgrnam(c_group).gr_gid
        except KeyError:
            base.pwrn(
                f"Unable to find group {c_group} for {cert_name}, fallback to root"
            )
            c_group_id = 0

        # Resolve per-cert permissions, falling back to defaults. cert_perms
        # applies to the public files only; key files are always 0640.
        key_file_perm = 0o640
        cert_file_perm = cert_perms if cert_perms is not None else 0o644

        # Check Certificate. Avoid deploying wrong certificates
        if args.ignore_ssl_check is False:
            test_result = cert_ops.certificate_validate(certificate_t)

            if test_result is not True:
                base.pwrn(f"Certificate {str(cert_name)} not pass the checks, skipping")
                errors_count += 1
                # Skip all and move to next cert
                continue

        # Create cert directories
        for path, perms in [(cert_dir_cert, 0o0755), (cert_dir_private, 0o0750)]:
            cert_ops.create_directory(path, perms, c_user_id, c_group_id)

        # Add bundlekey (bundle followed by the private key)
        bundlekey_keys = (certificate["data"]["bundle"], certificate["data"]["key"])
        certificate["data"]["bundlekey"] = "\n".join(bundlekey_keys)

        # Optional extra copy destination; a failure is counted, not fatal
        copy_files_enabled = False
        if cert_copypath:
            if cert_ops.create_directory(
                cert_copypath,
                0o755,
                c_user_id,
                c_group_id,
                exit_on_error=False,
                keep_previous_ownership=True,
            ):
                copy_files_enabled = True
            else:
                errors_count += 1

        # Create certificate files: "key" and "bundlekey" go to the private dir
        for key in certificate["data"].keys():
            if key in ["key", "bundlekey"]:
                file_path = cert_dir_private + "/" + cert_name + "." + key
            else:
                file_path = cert_dir_cert + "/" + cert_name + "." + key

            # Skip rewriting identical content so post hooks only run on change
            write_cert = True
            if os.path.isfile(file_path):
                with open(file_path, "r") as fh:
                    prev_cert = fh.read()
                if prev_cert == certificate["data"][key]:
                    write_cert = False
                    base.pdeb(
                        "Not writing "
                        + file_path
                        + " - already exists and is the same",
                        args.debug,
                    )

            if write_cert:
                post_hooks_enabled = True
                base.pdeb("Writing " + file_path, args.debug)
                # New files are created owner-only, so key material is never
                # readable by others before the chmod below sets the final mode.
                fd = os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
                with os.fdopen(fd, "w") as fh:
                    fh.write(certificate["data"][key])

            # Change file permissions
            if key in ["key", "bundlekey"]:
                os.chmod(file_path, key_file_perm)
            else:
                os.chmod(file_path, cert_file_perm)

            os.chown(file_path, c_user_id, c_group_id)

            # Copy cert file to additional path if cert_copypath is set
            if copy_files_enabled:
                base.pdeb("Copying into " + cert_copypath, args.debug)
                dst_path = cert_copypath + "/" + cert_name + "." + key
                shutil.copy2(file_path, dst_path)
                os.chown(dst_path, c_user_id, c_group_id)

    # Clean unwanted certificates. All requested names are kept, including
    # ones that failed to retrieve or validate in this run.
    clean_certificates(
        storage=storage_path, certificates=[e["name"] for e in cert_list]
    )

    # Eventually run post hooks
    if post_hooks_enabled:
        errors_count += cert_ops.run_post_hooks(post_hooks_dir, args.debug)

    if errors_count > 0:
        base.eexit(1, f"There was {errors_count} errors during process")
