#!/usr/bin/python3
#
# A simple server-side libvirt-interfacing helper that can reserve domains,
# install them, clone their disks, etc., to enable their use as transient
# testing systems by multiple users, in a safe race-free way.
# Part of the ATEX project - https://github.com/RHSecurityCompliance/atex.

import argparse
import atexit
import ctypes
import errno
import fcntl
import json
import logging
import os
import pty
import random
import re
import select
import struct
import subprocess
import sys
import tempfile
import time
import xml.etree.ElementTree as ET
from logging.handlers import SysLogHandler
from pathlib import Path

# ctypes handle used to call prctl(2); use_errno=True makes the errno of
# a failed call available through ctypes.get_errno()
libc = ctypes.CDLL(None, use_errno=True)
# from /usr/include/linux/prctl.h:
PR_SET_NAME = 15
# prepare args / return for prctl(2)
# - declared explicitly so ctypes converts the five arguments passed by
#   set_procname() (int, bytes, three integers) and the int result
libc.prctl.argtypes = (
    ctypes.c_int, ctypes.c_char_p,
    ctypes.c_ulong, ctypes.c_ulong, ctypes.c_ulong,
)
libc.prctl.restype = ctypes.c_int

# URI passed via CLI args
# - when unset, no -c / --connect option is passed to virsh / virt-install
libvirt_uri = None
# Path, used to synchronize multiple instances of this helper
# - set by main(), lock files live in its 'locks' subdirectory
state_dir = None
# dict of names (keys) and OS file descriptors (values) opened by lock()
state_locks = {}

# these will be re-opened on descriptors beyond 2
# - main() stores dup()s of the original fd 1 / fd 2 here before pointing
#   fd 1 and fd 2 at os.devnull
clean_stdout = clean_stderr = None


def set_procname(name):
    """
    Rename the caller via prctl(PR_SET_NAME).

    Only the first 15 characters of 'name' are passed on.
    Raises OSError carrying the errno reported by prctl() on failure.
    """
    # the char* argument is NUL-terminated explicitly
    c_name = name[:15].encode() + b"\x00"
    rc = libc.prctl(PR_SET_NAME, c_name, 0, 0, 0)
    if rc == 0:
        return
    code = ctypes.get_errno()
    raise OSError(code, os.strerror(code))


def lock(name, op=fcntl.F_SETLK):
    """
    Take an exclusive fcntl record lock on the lock file called 'name'.

    The lock file lives in the 'locks' subdirectory of state_dir and is
    created if missing. 'op' is the fcntl command used to place the lock
    (F_SETLK by default, which does not wait for a conflicting lock).

    Returns True if the lock is now held by us (or already was), False if
    somebody else holds it. On success, the open descriptor is remembered
    in state_locks, since the lock lives only as long as the descriptor.
    Any other failure is re-raised after the descriptor has been closed.
    """
    if name in state_locks:
        return True
    locks_dir = state_dir / "locks"
    # see 'struct flock', which we pack here
    lock_data = struct.pack("@hhqqi", fcntl.F_WRLCK, os.SEEK_SET, 0, 0, 0)
    fd = os.open(locks_dir / name, os.O_WRONLY | os.O_CREAT, 0o666)
    try:
        fcntl.fcntl(fd, op, lock_data)
    except BaseException as e:
        # never leak the descriptor, whatever went wrong
        os.close(fd)
        # fcntl(2) reports a conflicting lock as either EAGAIN or EACCES
        if isinstance(e, OSError) and e.errno in (errno.EAGAIN, errno.EACCES):
            return False
        raise
    state_locks[name] = fd
    return True


def unlock(name):
    """
    Forget the lock 'name' and close the descriptor lock() stored for it.

    Names that are not in state_locks are silently ignored.
    """
    if (fd := state_locks.pop(name, None)) is not None:
        os.close(fd)


def getlock(name):
    """
    Report who holds the lock called 'name'.

    Returns -1 if we hold it ourselves (it is in state_locks), None if
    nobody does, or the PID reported by F_GETLK for the holder.
    """
    # if held by us
    # - F_GETLK can't be used for our own 'state_locks' (we already hold
    #   the locks), those would *always* get F_UNLCK
    if name in state_locks:
        return -1

    # when used with F_GETLK, this does not actually lock the fd,
    # but the struct is still needed to see if there *would* be a locking
    # conflict (another process holds the lock)
    probe = struct.pack("@hhqqi", fcntl.F_WRLCK, os.SEEK_SET, 0, 0, 0)

    fd = os.open(state_dir / "locks" / name, os.O_WRONLY | os.O_CREAT, 0o666)
    try:
        raw = fcntl.fcntl(fd, fcntl.F_GETLK, probe)
    finally:
        os.close(fd)

    lock_type, _whence, _start, _length, holder_pid = struct.unpack("@hhqqi", raw)
    if lock_type == fcntl.F_UNLCK:
        return None
    return holder_pid


def virsh(*args):
    """
    Run 'virsh -q' with the given arguments, against libvirt_uri if set.

    Returns a (exit code, output) tuple, where output is the combined
    stdout + stderr of virsh as text.
    """
    argv = ["virsh", "-q"]
    if libvirt_uri:
        argv += ["-c", libvirt_uri]
    argv += args
    finished = subprocess.run(
        argv,
        text=True,
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        timeout=3600,
    )
    return (finished.returncode, finished.stdout)


def cmd_ping(_logger, reply, _args):
    """Liveness check: always reply success with the helper's version banner."""
    banner = "atex-virt-helper v1 pong"
    reply(True, banner)


def cmd_setname(_logger, reply, args):
    """
    Rename this helper process to args["name"] via set_procname().

    A name that is empty or only whitespace is refused.
    """
    requested = args["name"]
    if requested.strip():
        set_procname(requested)
        reply(True)
    else:
        reply(False, "name must not be empty")


def cmd_virsh(_logger, reply, args):
    """Run virsh with the argument list in args["args"] and reply with its output."""
    exit_code, text = virsh(*args["args"])
    succeeded = exit_code == 0
    reply(succeeded, text)


def cmd_reserve(logger, reply, args):
    """
    Reserve (lock) one persistent domain for this helper instance.

    If args["filter"] is given, only domains whose whole name matches that
    regex are considered. Candidates are tried in random order and the
    first one that can be locked is returned as 'domain' in the reply.
    """
    code, output = virsh("list", "--all", "--persistent", "--name")
    if code != 0:
        msg = f"virsh list failed with {code}: {output}"
        logger.error(msg)
        reply(False, msg)
        return

    listing = output.rstrip("\n")
    if not listing:
        reply(False, "no persistent domains found")
        return
    candidates = listing.split("\n")

    pattern = args.get("filter")
    if pattern is not None:
        candidates = [dom for dom in candidates if re.fullmatch(pattern, dom)]
        if not candidates:
            reply(False, "no domain matches the filter")
            return

    # don't try to lock domains already locked
    candidates = [dom for dom in candidates if dom not in state_locks]

    # random order, so that helpers started together don't all go after
    # the same domain first
    random.shuffle(candidates)
    reserved = next((dom for dom in candidates if lock(dom)), None)
    if reserved is None:
        reply(False, "no domain could be reserved")
    else:
        reply(True, domain=reserved)


def cmd_release(_logger, reply, args):
    """Give up our reservation of args["domain"]; fails if we don't hold it."""
    domain = args["domain"]
    if domain in state_locks:
        unlock(domain)
        reply(True)
    else:
        reply(False, "domain not locked")


def cmd_reservations(logger, reply, _args):
    """
    Reply with a 'domains' dict describing who reserved each persistent domain.

    Each value is "(free)", "(reserved for us)", or the process name
    (/proc/PID/comm) of the lock holder, with placeholders for a holder
    that vanished or has an empty name.
    """
    code, output = virsh("list", "--all", "--persistent", "--name")
    if code != 0:
        msg = f"virsh list failed with {code}: {output}"
        logger.error(msg)
        reply(False, msg)
        return

    listing = output.rstrip("\n")
    if not listing:
        reply(False, "no persistent domains found")
        return

    status = {}
    for domain in listing.split("\n"):
        holder = getlock(domain)
        if holder is None:
            status[domain] = "(free)"
        elif holder == -1:
            status[domain] = "(reserved for us)"
        else:
            comm_file = Path("/proc") / str(holder) / "comm"
            try:
                comm = comm_file.read_text().rstrip("\n")
            except FileNotFoundError:
                # the holder exited between F_GETLK and the read
                status[domain] = f"(pid {holder} disappeared)"
            else:
                status[domain] = comm if comm.strip() else "(empty name)"
    reply(True, domains=status)


def cmd_upload(logger, reply, args):
    """
    Store args["length"] bytes read from stdin as file args["name"] in the
    current working directory.

    The payload follows the JSON request line on stdin and the reply is
    only sent once it has been read, so the sender cannot know about a
    failure before it has sent the data. Therefore, when the upload is
    refused (bad name) or the file cannot be opened / written, the rest of
    the announced payload is still read and thrown away before replying
    - otherwise it would be parsed as further commands.

    If stdin ends before 'length' bytes arrived, the error is only logged
    and no reply is sent.
    """
    name = args["name"]
    length = args["length"]

    # without a usable length, there is no way to tell where the payload ends
    if not isinstance(length, int):
        reply(False, "length must be an integer")
        return

    def discard(count):
        # read and drop 'count' bytes of stdin, False on premature EOF
        while count > 0:
            chunk = sys.stdin.buffer.read(min(count, 10240))
            if not chunk:
                logger.error("unexpected EOF during upload")
                return False
            count -= len(chunk)
        return True

    if not isinstance(name, str):
        problem = "name must be a string"
    elif "/" in name:
        problem = "name cannot have '/'"
    else:
        problem = None
    if problem:
        if discard(length):
            reply(False, problem)
        return

    try:
        with open(name, "wb") as f:
            while length > 0:
                chunk = sys.stdin.buffer.read(min(length, 10240))
                if not chunk:
                    logger.error("unexpected EOF during upload")
                    return
                # account for the chunk before writing it, so 'length' is
                # what is left on stdin even if the write fails
                length -= len(chunk)
                f.write(chunk)
    except OSError as e:
        msg = f"failed writing '{name}': {type(e).__name__}({e})"
        logger.error(msg)
        if discard(length):
            reply(False, msg)
        return
    reply(True)


def cmd_virt_install(logger, reply, args):
    """
    Run 'virt-install' with the argument list in args["args"] on a
    pseudo-terminal, forwarding everything it prints to clean_stderr.

    If args["destroy_on_error"] is truthy, the domain named by -n / --name
    is 'virsh destroy'ed whenever the installation does not end with
    exit code 0 (non-zero exit, exception, or clean_stderr going away).

    Replies success only if virt-install exited with 0.
    """
    connect = ("--connect", libvirt_uri) if libvirt_uri else ()
    virt_install_args = args["args"]

    if args.get("destroy_on_error"):
        # find where the domain name option is, the name itself is looked up
        # later as the argument following it
        # NOTE(review): only the separate-argument forms '-n NAME' and
        # '--name NAME' are found here, not '--name=NAME'
        try:
            idx = virt_install_args.index("-n")
        except ValueError:
            try:
                idx = virt_install_args.index("--name")
            except ValueError:
                reply(False, "destroy_on_error requested, but no -n/--name given in args")
                return

        def destroy_if_requested():
            # NOTE(review): raises IndexError if -n/--name is the last argument
            name = virt_install_args[idx+1]
            logger.debug(f"attempting to 'virsh destroy {name}'")
            virsh("destroy", name)
    else:
        # same name as above, so the code below can call it unconditionally
        def destroy_if_requested():
            pass

    # simulate a TTY for virt-install, so that it can output console
    # via an (internal) 'virsh console' call
    proc = None
    m_fd, s_fd = pty.openpty()
    try:
        try:
            # the child gets the slave end as stdin and stdout (+ stderr)
            proc = subprocess.Popen(
                ("virt-install", *connect, *args["args"]),
                stdin=s_fd,
                stdout=s_fd,
                stderr=subprocess.STDOUT,
            )
        finally:
            # our copy of the slave end is closed whether or not Popen worked
            os.close(s_fd)

        # use poll() to check up on clean_stderr regularly, and trigger an
        # exception if it goes away (ssh disconnects), to avoid virt-install
        # being stuck in an interactive session forever (after ssh discon's)
        poller = select.poll()
        poller.register(m_fd, select.POLLIN)
        # 0 flags still wakes up for POLLERR/POLLHUP
        poller.register(clean_stderr, 0)

        while True:
            events = poller.poll()
            for fd, event in events:
                if fd == clean_stderr:
                    # registered with no flags, so any event on it is an
                    # error / hangup
                    # NOTE(review): the terminated child is not wait()ed
                    # for on this path
                    proc.terminate()
                    msg = "stderr disconnect while virt-install was running"
                    reply(False, msg)
                    logger.error(msg)
                    destroy_if_requested()
                    return
                elif fd == m_fd:
                    if event & select.POLLIN:
                        try:
                            data = os.read(m_fd, 1024)
                        except OSError as e:
                            if e.errno == errno.EIO:  # child exited
                                break
                            raise
                        if data == b"":
                            break
                        # os.write() may write less than asked for,
                        # repeat until the whole chunk is forwarded
                        while data != b"":
                            written = os.write(clean_stderr, data)
                            data = data[written:]
                    elif event & (select.POLLERR | select.POLLHUP):
                        break  # child exited (pipe broken)
            else:
                continue  # the inner 'for' didn't break
            # the inner 'for' did break, there is nothing more to read
            break

    except BaseException as e:
        # anything unexpected: stop virt-install (if it ever started),
        # report the error and clean up the domain if asked to
        if proc is not None:
            proc.terminate()
        reply(False, f"{type(e).__name__}({e})")
        logger.exception("unknown error")
        destroy_if_requested()
        return
    finally:
        # runs on every path above, including both early returns
        os.close(m_fd)

    # the output ended, collect the exit code
    code = proc.wait()
    if code != 0:
        destroy_if_requested()
        reply(False, f"virt-install exited with {code}")
    else:
        reply(True)


def cmd_create_volume(logger, reply, args):
    """
    Create an empty volume (disk image) inside a type=dir storage pool.

    Arguments (keys of 'args'):
      pool            - name of the libvirt storage pool
      name            - file name of the new volume, must not contain '/'
      size            - size in bytes (anything int() accepts)
      format          - "raw" (sparse file) or "qcow2"
      remove_existing - optional; if truthy, an existing file of that name
                        is deleted first instead of failing

    The image file is created directly in the pool's directory and the pool
    is then refreshed, so that libvirt picks the new volume up.
    """
    pool = args["pool"]
    name = args["name"]
    size = int(args["size"])
    fmt = args["format"]
    remove_existing = bool(args.get("remove_existing"))

    if fmt not in ("raw", "qcow2"):
        reply(False, f"invalid format: {fmt}")
        return

    # the name is joined onto the pool directory below, so a '/' in it could
    # make us delete / create files anywhere on the system
    if "/" in name:
        reply(False, "name cannot have '/'")
        return

    # look up the pool, ensure it's type=dir
    code, output = virsh("pool-dumpxml", pool)
    if code != 0:
        reply(False, f"failed getting pool '{pool}': {output}")
        return
    xml_root = ET.fromstring(output)
    if xml_root.get("type") != "dir":
        reply(False, "only type=dir storage pools are supported")
        return

    # find the on-disk pool directory
    pool_dir = None
    if (target := xml_root.find("target")) is not None:
        if (path := target.find("path")) is not None:
            pool_dir = path.text
    if not pool_dir:
        reply(False, f"could not find on-disk location of pool {pool}")
        return

    vol_path = Path(pool_dir) / name
    if vol_path.exists():
        if remove_existing:
            vol_path.unlink()
        else:
            reply(False, f"'{name}' already exists: {str(vol_path)}")
            return

    # create the image
    if fmt == "raw":
        # always use sparse
        vol_path.touch(mode=0o600)
        os.truncate(vol_path, size)
    else:
        # always use preallocation=metadata (has a small size impact)
        proc = subprocess.run(
            (
                "qemu-img", "create", "-f", "qcow2", "-o", "preallocation=metadata",
                vol_path, str(size),
            ),
            # text output, so the error message below is readable
            text=True,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            timeout=3600,
        )
        if proc.returncode != 0:
            msg = f"failed 'qemu-img': {proc.stdout}"
            reply(False, msg)
            logger.error(msg)
            return

    # refresh the storage pool to register the new volume
    # - try a few times to account for race conditions in type=dir pools
    for _ in range(10):
        code, output = virsh("pool-refresh", pool)
        if code == 0:
            break
        if "has asynchronous jobs running" not in output:
            reply(False, f"failed refreshing pool '{pool}': {output}")
            return
        time.sleep(0.1)
    else:
        # still busy after all the attempts
        reply(False, f"failed refreshing pool '{pool}': {output}")
        return

    reply(True)


def cmd_copy_volume(logger, reply, args):
    """
    Copy (or move) a volume to another volume name inside a type=dir pool.

    Arguments (keys of 'args'):
      pool      - name of the libvirt storage pool
      from      - name of the source volume
      to        - name of the destination volume, must not contain '/'
      to_domain - alternative to 'to': use the first type=volume disk of
                  this domain that lives in 'pool' as the destination;
                  the domain's nvram file (if any) is removed afterwards
      move      - optional; if truthy, rename the source instead of copying

    The destination file is created next to the source file and put in
    place with an atomic rename; the pool is refreshed afterwards.
    """
    pool = args["pool"]

    # the destination name is joined onto the source's directory below, so
    # a '/' in it could make us overwrite files anywhere on the system
    if "/" in args.get("to", ""):
        reply(False, "volume name cannot have '/'")
        return

    # look up the pool, ensure it's type=dir
    code, output = virsh("pool-dumpxml", pool)
    if code != 0:
        reply(False, f"failed getting pool '{pool}': {output}")
        return
    xml_root = ET.fromstring(output)
    if xml_root.get("type") != "dir":
        reply(False, "only type=dir storage pools are supported")
        return

    # translate domain to volume
    nvram_path = None
    if "to_domain" in args:
        to_domain = args["to_domain"]
        code, output = virsh("dumpxml", to_domain)
        if code != 0:
            reply(False, f"failed getting domain '{to_domain}': {output}")
            return
        xml_root = ET.fromstring(output)

        xml_devices = xml_root.find("devices")
        if xml_devices is None:
            reply(False, f"domain '{to_domain}' has no '<devices>'")
            return

        # find a storage disk attached to the specified pool
        to_vol = None
        for xml_disk in xml_devices.findall("disk"):
            if xml_disk.get("type") != "volume":
                continue
            xml_disk_source = xml_disk.find("source")
            if xml_disk_source is None:
                continue
            if xml_disk_source.get("pool") != pool:
                continue
            # a <source> without a volume name is of no use, keep looking
            to_vol = xml_disk_source.get("volume")
            if to_vol:
                break
        if not to_vol:
            reply(False, f"domain '{to_domain}' has no volume in pool '{pool}'")
            return
        logger.debug(f"found volume for '{to_domain}': {to_vol}")
        if "/" in to_vol:
            reply(False, "volume name cannot have '/'")
            return

        # find any specified nvram file path
        if (xml_os := xml_root.find("os")) is not None:
            if (xml_nvram := xml_os.find("nvram")) is not None:
                if xml_nvram.text:
                    logger.debug(f"found nvram for '{to_domain}': {xml_nvram.text}")
                    nvram_path = Path(xml_nvram.text)
    else:
        to_vol = args["to"]

    # get actual file path for the source volume,
    # reuse the parent directory of its file for the destination one
    # (we could look up pool dir path, but this is safer for atomic moves)
    # - returns None after having sent a failure reply
    def vol_to_path(vol_name):
        code, output = virsh("vol-dumpxml", vol_name, pool)
        if code != 0:
            reply(False, f"failed getting '{vol_name}': {output}")
            return None
        xml_root = ET.fromstring(output)
        if (xml_target := xml_root.find("target")) is not None:
            if (xml_path := xml_target.find("path")) is not None:
                if xml_path.text:
                    return Path(xml_path.text)
        # the caller just returns on None, so make sure a reply goes out
        reply(False, f"could not find on-disk path of volume '{vol_name}'")
        return None

    from_vol = args["from"]
    if (from_path := vol_to_path(from_vol)) is None:
        return
    to_path = from_path.parent / to_vol

    if args.get("move"):
        # just move (rename) the image
        from_path.replace(to_path.absolute())
    else:
        # copy (reflink) the source image to the destination
        # - use 'cp' instead of shutil.copy* because shutil is super slow
        #   for files with holes / causes tons of IO
        logger.info(f"copying '{from_vol}' ({from_path}) -> '{to_vol}' ({to_path})")
        fd, tmpf = tempfile.mkstemp(dir=to_path.parent.absolute())
        os.close(fd)
        tmp_path = Path(tmpf)
        try:
            proc = subprocess.run(
                ("cp", "-f", "--reflink=auto", from_path.absolute(), tmpf),
                # text output, so the error message below is readable
                text=True,
                stdin=subprocess.DEVNULL,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                timeout=3600,
            )
            if proc.returncode != 0:
                msg = f"failed 'cp': {proc.stdout}"
                reply(False, msg)
                logger.error(msg)
                return
            tmp_path.replace(to_path.absolute())
        finally:
            # no-op after a successful rename, otherwise (cp failed, timed
            # out, ...) this removes the leftover temporary file
            tmp_path.unlink(missing_ok=True)

    # if the domain uses UEFI/SecureBoot, try to remove its nvram image,
    # freeing any previous OS installation metadata
    if nvram_path:
        logger.info(f"removing nvram for '{to_domain}': {nvram_path}")
        nvram_path.unlink(missing_ok=True)

    # refresh the storage pool to register the new volume
    # - try a few times to account for race conditions in type=dir pools
    for _ in range(10):
        code, output = virsh("pool-refresh", pool)
        if code == 0:
            break
        if "has asynchronous jobs running" not in output:
            reply(False, f"failed refreshing pool '{pool}': {output}")
            return
        time.sleep(0.1)
    else:
        # still busy after all the attempts
        reply(False, f"failed refreshing pool '{pool}': {output}")
        return

    reply(True)


class SilentError(Exception):
    """Raised by arg_error() once a failure was logged and replied to."""


def reply(success, msg=None, cmd=None, /, **kwargs):
    """
    Send one JSON response line to clean_stdout.

    The response object has 'success', then 'cmd' (if given), then any
    extra keyword arguments, then 'reply' holding 'msg' (if not None).
    The three named parameters are positional-only, so keyword arguments
    of the same names end up as regular keys in the response.

    Raises whatever json.dumps() raises for a response that cannot be
    serialized; nothing is written to clean_stdout in that case.
    """
    response = {"success": success}
    if cmd:
        response["cmd"] = cmd
    response |= kwargs
    if msg is not None:
        response["reply"] = msg
    logger = logging.getLogger(cmd) if cmd else logging
    logger.debug(f"sending reply: {response}")
    # serialize before touching the output, so a failure can never leave
    # a truncated JSON line on the protocol channel
    line = json.dumps(response) + "\n"
    # closefd=False: the descriptor must survive for further replies
    with os.fdopen(clean_stdout, "w", closefd=False) as stdout_fobj:
        stdout_fobj.write(line)


def arg_error(msg):
    """
    Report a problem with the arguments of a request and abort handling it.

    The message is logged as a warning and sent as a failure reply, then
    SilentError is raised (always) to unwind out of the command handling.
    """
    logging.warning(msg)
    reply(False, msg)
    already_reported = SilentError()
    raise already_reported from None


def handle_cmd(cmd, args):
    """
    Validate the arguments of one request and run the matching cmd_* function.

    'args' is the request dict without its 'cmd' key; recognized arguments
    are removed from it, and anything left over is reported as an error.
    The cmd_* function gets a per-command logger, a reply function that
    echoes the accepted arguments back, and the accepted arguments.
    """
    # cmd -> (function, required argument names, optional argument names)
    commands = {
        "ping": (cmd_ping, (), ()),
        "setname": (cmd_setname, ("name",), ()),
        "virsh": (cmd_virsh, ("args",), ()),
        "reserve": (cmd_reserve, (), ("filter",)),
        "release": (cmd_release, ("domain",), ()),
        "reservations": (cmd_reservations, (), ()),
        "upload": (cmd_upload, ("name", "length"), ()),
        "virt-install": (cmd_virt_install, ("args",), ("destroy_on_error",)),
        "create-volume": (
            cmd_create_volume,
            ("pool", "name", "size", "format"),
            ("remove_existing",),
        ),
        "copy-volume": (
            cmd_copy_volume,
            ("pool", "from"),
            ("to", "to_domain", "move"),
        ),
    }

    try:
        func, required, optional = commands[cmd]
    except (KeyError, TypeError):
        # TypeError: 'cmd' was something unhashable (JSON array / object)
        reply(False, f"invalid cmd: '{cmd}'")
        return

    def take(names, *, mandatory):
        # move the named arguments out of 'args', in the order given
        taken = {}
        for name in names:
            if name in args:
                taken[name] = args.pop(name)
            elif mandatory:
                arg_error(f"missing argument for cmd '{cmd}': {name}")
        return taken

    valid_args = take(required, mandatory=True)
    valid_args |= take(optional, mandatory=False)

    if cmd == "copy-volume":
        # exactly one way of specifying the destination must be used
        has_to = "to" in valid_args
        has_to_domain = "to_domain" in valid_args
        if not has_to and not has_to_domain:
            arg_error(f"one of 'to' or 'to_domain' needed for '{cmd}'")
        elif has_to and has_to_domain:
            arg_error(f"only one of 'to' or 'to_domain' allowed for '{cmd}'")

    # check leftover args
    if args:
        msg = f"leftover arguments for cmd '{cmd}': {args}"
        logging.warning(msg)
        reply(False, msg)
        return

    logging.info(f"running '{cmd}': {valid_args}")

    def reply_for_cmd(success, msg=None, /, **kwargs):
        # use a per-cmd-specific logger and
        # pre-load the reply() JSON with valid user-passed arguments
        return reply(success, msg, cmd, **valid_args, **kwargs)

    return func(logging.getLogger(cmd), reply_for_cmd, valid_args)


def loop():
    """
    Serve requests from stdin until EOF.

    Each request is one line (read in pieces of at most 1 MiB) holding a
    JSON object with a 'cmd' key; all other keys are passed to handle_cmd()
    as the arguments of that command. A request that cannot be parsed, or
    a command that fails, results in a failure reply - it never ends
    the loop.
    """
    while line := sys.stdin.buffer.readline(1048576):
        logging.debug(f"received line: {line}")

        try:
            request = json.loads(line)
        except (ValueError, RecursionError) as e:
            # ValueError covers both JSONDecodeError and the
            # UnicodeDecodeError raised for bytes that are not valid text,
            # RecursionError is raised for absurdly deep nesting
            reply(False, f"{type(e).__name__}({e})")
            continue

        # valid JSON doesn't have to be an object (could be a string, array,
        # number, ...), but the code below needs dict semantics
        if not isinstance(request, dict):
            reply(False, "the JSON request must be an object")
            continue

        if "cmd" not in request:
            reply(False, "missing 'cmd' in the JSON request")
            continue

        cmd = request.pop("cmd")
        args = request
        try:
            handle_cmd(cmd, args)
        except SilentError:
            # already logged and replied to
            continue
        except BaseException as e:
            logging.exception(f"{type(e)} while handling cmd '{cmd}'")
            reply(False, f"{type(e).__name__}({e}) while handling cmd '{cmd}'")
            continue


def main():
    """
    Entry point: parse CLI arguments, set up logging, the protocol
    descriptors and the state directory, then serve requests via loop().

    Exits with 1 (via SystemExit) on any unexpected exception.
    """
    global libvirt_uri, clean_stdout, clean_stderr, state_dir

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--log", "-L", choices=("debug", "info", "warning", "error"),
        help="write to syslog, >= this log level",
    )
    cli.add_argument(
        "libvirt_uri",
        help="libvirt URI to connect to (ie. qemu:///system)",
        nargs="?",
    )
    opts = cli.parse_args()

    libvirt_uri = opts.libvirt_uri

    # set up logging
    if opts.log:
        syslog_levels = {
            "debug": logging.DEBUG,
            "info": logging.INFO,
            "warning": logging.WARNING,
            "error": logging.ERROR,
        }
        prog = Path(sys.argv[0]).stem
        logging.basicConfig(
            level=syslog_levels[opts.log],
            format=f"{prog}: %(name)s: %(message)s",
            handlers=(SysLogHandler(address="/dev/log"),),
            force=True,  # only syslog
        )
    else:
        # make sure logging calls are no-op
        logging.basicConfig(
            level=logging.CRITICAL,
            handlers=(logging.NullHandler(),),
            force=True,
        )

    logging.info("starting up")
    atexit.register(logging.info, "shutting down")

    try:
        # move stdout/stderr to fds beyond 2 to prevent python itself
        # from polluting them and corrupting our protocol
        # (exceptions, the warnings module, etc.)
        clean_stdout = os.dup(1)
        clean_stderr = os.dup(2)
        null_fd = os.open(os.devnull, os.O_RDWR)
        os.dup2(null_fd, 1)
        os.dup2(null_fd, 2)
        os.close(null_fd)

        # create state dir
        # - in the home directory for non-root users that have one,
        #   system-wide otherwise
        home_dir = Path.home()
        use_home = os.geteuid() != 0 and home_dir.is_dir()
        if use_home:
            state_dir = home_dir / ".atex-virt-helper"
        else:
            state_dir = Path("/var/lib/atex-virt-helper")
        try:
            for needed in (state_dir, state_dir / "locks"):
                needed.mkdir(exist_ok=True)
        except OSError as e:
            logging.error(f"failed to mkdir: {type(e).__name__}({e})")
            raise SystemExit(1) from None

        # run in a unique tmpdir for cmds like 'upload'
        with tempfile.TemporaryDirectory(dir="/var/tmp") as tmpdir:
            os.chdir(tmpdir)
            loop()
    except BaseException:
        logging.exception("unexpected condition")
        raise SystemExit(1) from None


# start serving only when run as a script, not when imported as a module
if __name__ == "__main__":
    main()
