#!/usr/bin/env python3
"""bty live env: flash-on-boot.

Reads ``bty.*`` parameters from ``/proc/cmdline``, fetches the
assigned image from the bty server, runs ``bty flash``, signals
completion, and reboots.

Driven by ``bty-flash-on-boot.service`` after ``network-online.target``.
If any of ``bty.server``, ``bty.mac``, ``bty.image_url`` or
``bty.target_disk_serial`` is absent from /proc/cmdline (e.g. the
operator is booting the live env interactively for debugging), the
service exits 0 and the live env drops to its console.

Recognised cmdline parameters:
- ``bty.server=URL``              - base URL of the bty-web server
- ``bty.mac=MAC``                 - this machine's MAC (xx-xx-...)
- ``bty.image_url=URL``           - direct URL to the image to flash
- ``bty.target_disk_serial=...``  - serial number of the disk to
                                    flash. REQUIRED. Matched against
                                    ``lsblk -d -o SERIAL`` at boot
                                    time. If the serial is missing
                                    or doesn't match any present
                                    disk, the script bails rather
                                    than picking the wrong drive.
                                    Operator picks via /ui/machines/
                                    {mac}; the server bakes it into
                                    the iPXE flash chain via
                                    ``ipxe_flash.j2``.
- ``bty.mode=interactive``        - short-circuit; ``bty-tui-on-tty1.service``
                                    handles the interactive flash flow on
                                    tty1 instead. Set by the server's
                                    ``ipxe_tui.j2`` for ``boot_policy=tui``
                                    (default for unknown MACs).
"""

from __future__ import annotations

import json
import shlex
import subprocess
import sys
import urllib.parse
import urllib.request
from pathlib import Path

# Kernel command line; the only place this script reads its parameters from.
CMDLINE = Path("/proc/cmdline")
# Directory the downloaded image is cached in before ``bty flash`` runs.
LOCAL_IMAGE_DIR = Path("/var/tmp")

# cmdline keys that must all be present before an unattended flash is
# attempted; otherwise the service exits 0 and leaves the console alone.
REQUIRED = (
    "bty.server",
    "bty.mac",
    "bty.image_url",
    "bty.target_disk_serial",
)


def local_image_path(image_url: str, base_dir: Path | None = None) -> Path:
    """Pick a local cache path that preserves the URL's filename.

    ``bty.images.detect_format()`` keys off the file extension
    (.qcow2, .img, .img.{zst,xz,gz,bz2}) so the on-disk name has
    to keep it -- a fixed ``bty-flash-on-boot.image`` would
    always fail validation.

    Only the URL's *path* component is consulted, so query strings
    and fragments never end up in the filename. If the path has no
    usable final component, the name falls back to
    ``bty-flash-on-boot.img``. Examples are ``http://host/`` or
    ``http://host/images/..``.

    ``..`` has to be rejected explicitly. ``pathlib`` keeps it as the
    final component, and joining it onto the cache directory would
    point at that directory's parent instead of a file inside it.

    :param image_url: URL of the image that will be downloaded.
    :param base_dir: directory to place the file in; defaults to
        ``LOCAL_IMAGE_DIR`` when omitted.
    :returns: ``base_dir / <filename>``.
    """
    parsed_name = Path(urllib.parse.urlparse(image_url).path).name
    if parsed_name in ("", ".."):
        parsed_name = "bty-flash-on-boot.img"
    directory = LOCAL_IMAGE_DIR if base_dir is None else base_dir
    return directory / parsed_name


def _split_cmdline(raw: str) -> list[str]:
    """Split a kernel command line into words the way the kernel does.

    Only double quotes are special on the kernel command line; they
    let a value contain whitespace and are removed from the result.
    Single quotes and backslashes are ordinary characters. An
    unbalanced double quote extends to the end of the line instead
    of raising. ``shlex.split`` raises ``ValueError`` in that case,
    which is why it is not used here.
    """
    words: list[str] = []
    current: list[str] = []
    in_word = False  # True once the current word has any content/quote
    in_quote = False
    for ch in raw:
        if ch == '"':
            in_quote = not in_quote
            in_word = True  # "" still yields an (empty) word
        elif ch.isspace() and not in_quote:
            if in_word:
                words.append("".join(current))
                current = []
                in_word = False
        else:
            current.append(ch)
            in_word = True
    if in_word:
        words.append("".join(current))
    return words


def cmdline_args(raw: str | None = None) -> dict[str, str]:
    """Parse the kernel command line into a dict of ``bty.*`` -> value.

    Only ``bty.<key>=<value>`` words are kept; everything else on the
    command line (``console=``, ``quiet``, bare ``bty.foo`` without
    ``=``) is ignored. If a key appears more than once, the last
    occurrence wins.

    :param raw: command-line text to parse; when omitted, the contents
        of ``CMDLINE`` are read.
    :returns: mapping of full key (e.g. ``"bty.server"``) to its value.
    """
    if raw is None:
        raw = CMDLINE.read_text()
    out: dict[str, str] = {}
    for token in _split_cmdline(raw):
        if token.startswith("bty.") and "=" in token:
            k, _, v = token.partition("=")
            out[k] = v
    return out


def pick_target(target_serial: str) -> str:
    """Resolve the disk whose SERIAL matches ``target_serial``.

    Serial (vs /dev/sda) is the durable identifier baked into the
    iPXE flash chain by bty-web. ``/dev/sda`` can flip across
    kernel versions and udev rules but the disk's serial number
    is fixed. Matching on serial means a swapped-out or
    renumbered drive triggers a clean refusal here instead of
    wiping the wrong target.

    Shells out to ``lsblk -d -e7 -J`` (``-d`` strips partitions,
    ``-e7`` excludes loop devices). ``-o`` adds SERIAL so we can
    compare directly; lsblk's structured JSON is what bty's own
    ``list disks`` codepath consumes too.

    The lookup refuses rather than guesses in three cases. Each one
    raises :class:`SystemExit` with an operator-readable message:

    * ``target_serial`` is empty or whitespace-only. Disks whose
      serial is blank or whitespace-only end up as ``""`` after the
      strip below, so an empty target would otherwise match (and
      wipe) such a disk.
    * No whole disk reports the serial.
    * More than one whole disk reports the serial, so the serial
      does not pick out a single drive.

    The disk list printed on failure helps an operator who can see
    tty1 (KVM / IPMI / serial console) figure out which disk to
    pick in /ui/machines/{mac}.

    :returns: the ``PATH`` lsblk reports for the single matching disk.
    :raises subprocess.CalledProcessError: if ``lsblk`` exits non-zero.
    """
    wanted = target_serial.strip()
    if not wanted:
        raise SystemExit(
            "bty-flash-on-boot: bty.target_disk_serial is empty; refusing "
            "to guess which disk to flash. Pick a target disk in "
            "/ui/machines and reboot."
        )

    out = subprocess.run(
        ["lsblk", "-d", "-e7", "-J", "-o", "PATH,SERIAL,RM,RO,TYPE,MODEL,SIZE"],
        check=True,
        capture_output=True,
        text=True,
    )
    payload = json.loads(out.stdout)
    candidates = []
    matches: list[str] = []
    for disk in payload.get("blockdevices") or []:
        if disk.get("type") != "disk":
            continue
        path = disk.get("path")
        if not path:
            continue
        # Mirror bty.disks.list_disks's whitespace-strip on serial
        # so the inventory side (which records the stripped value)
        # and the flash side (this lookup) agree on the canonical
        # form. Some USB enclosures / vendor firmware ship
        # serials with trailing whitespace.
        raw_serial = disk.get("serial")
        serial = raw_serial.strip() if isinstance(raw_serial, str) else raw_serial
        candidates.append((path, serial, disk.get("model"), disk.get("size")))
        if serial == wanted:
            matches.append(str(path))

    if len(matches) == 1:
        return matches[0]

    rendered = ", ".join(
        f"{p} (serial={s!r}, model={m!r}, size={sz!r})" for p, s, m, sz in candidates
    )
    if matches:
        # Several disks claim the same serial: flashing any of them
        # could be the wrong drive, so refuse.
        raise SystemExit(
            f"bty-flash-on-boot: {len(matches)} disks on this host report "
            f"serial={target_serial!r} ({', '.join(matches)}); refusing to "
            f"guess which one to flash. Current disks: {rendered}."
        )
    raise SystemExit(
        f"bty-flash-on-boot: no disk on this host has "
        f"serial={target_serial!r}. Current disks: "
        f"{rendered or '(none)'}. The operator's pick in /ui/machines "
        "is stale; re-pick after the next inventory."
    )


def download(url: str, dest: Path) -> None:
    """Stream ``url`` into ``dest``.

    Atomic via a sibling ``.partial`` file + rename so a torn
    download (server hangup, OOM kill, kernel panic) can't leave
    a half-written file where ``bty flash --yes`` would read it
    as a complete image. ``timeout=300`` on the HTTP read keeps
    the boot from hanging indefinitely on a misconfigured server.

    ``http.client`` does not raise when the server closes the
    connection before the full ``Content-Length`` body has arrived.
    ``read()`` simply returns ``b""`` as if the body had ended. So
    when the response declares a length, the number of bytes
    actually written is checked against it before the rename. A
    mismatch raises :class:`urllib.error.ContentTooShortError` and
    the partial file is removed.

    :raises urllib.error.URLError: on connection/HTTP failures or a
        truncated body (``ContentTooShortError`` is a subclass).
    """
    from urllib.error import ContentTooShortError

    print(f"bty-flash-on-boot: downloading {url} -> {dest}", flush=True)
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_suffix(dest.suffix + ".partial")
    try:
        with urllib.request.urlopen(url, timeout=300) as resp, tmp.open("wb") as f:
            try:
                expected = int(resp.headers.get("Content-Length", ""))
            except ValueError:
                expected = None  # absent or malformed: nothing to verify against
            written = 0
            while True:
                chunk = resp.read(1 << 20)
                if not chunk:
                    break
                f.write(chunk)
                written += len(chunk)
        if expected is not None and written != expected:
            raise ContentTooShortError(
                f"bty-flash-on-boot: download of {url} ended after {written} "
                f"bytes but the server declared Content-Length {expected}",
                None,
            )
        tmp.replace(dest)
    except BaseException:
        # Clean up the partial on any failure -- including
        # KeyboardInterrupt -- so a re-attempted boot starts fresh.
        try:
            tmp.unlink()
        except FileNotFoundError:
            pass
        raise


def signal_done(server: str, mac: str) -> None:
    """Tell the bty server this machine has finished flashing.

    Sends an empty-bodied POST to ``<server>/pxe/<mac>/done`` with a
    10-second timeout. This is best effort: any failure is reported
    on stderr and swallowed, so the caller goes on to reboot no matter
    what. That includes a malformed URL, a connection error, or an
    HTTP error status. The endpoint lands in D-3.
    """
    done_url = f"{server.rstrip('/')}/pxe/{mac}/done"
    try:
        # Request construction stays inside the try: a server value with
        # no URL scheme makes ``Request`` raise ValueError.
        request = urllib.request.Request(done_url, data=b"", method="POST")
        with urllib.request.urlopen(request, timeout=10):
            print(f"bty-flash-on-boot: completion signalled to {done_url}", flush=True)
    except Exception as exc:  # noqa: BLE001 -- best effort by design
        failure_note = (
            f"bty-flash-on-boot: completion signal to {done_url} failed: {exc}; "
            "continuing to reboot anyway"
        )
        print(failure_note, file=sys.stderr, flush=True)


def main() -> int:
    """Run one unattended flash-on-boot pass and return the exit status.

    Flow: parse ``bty.*`` keys from the kernel cmdline, download the
    image, resolve the target disk by serial, run ``bty flash``,
    signal completion (best effort), then request a reboot.

    Returns ``0`` without touching any disk in two cases:

    * ``bty.mode=interactive`` is set.
    * A required key is missing or has an empty value. For example,
      ``bty.target_disk_serial=`` with nothing after the ``=`` must
      not be treated as a usable target.

    Failures in the download, the disk lookup, ``bty flash`` or
    ``systemctl reboot`` propagate as exceptions or
    :class:`SystemExit`, so the process exits non-zero.
    """
    args = cmdline_args()

    # Interactive mode (boot_policy=tui on the server side) -> defer
    # to bty-tui-on-tty1.service; that unit owns tty1 and runs
    # bty-tui --server URL --mac MAC for the operator. Trying to
    # flash here would race the TUI session.
    if args.get("bty.mode") == "interactive":
        print(
            "bty-flash-on-boot: bty.mode=interactive on cmdline; "
            "bty-tui-on-tty1.service handles this boot",
            flush=True,
        )
        return 0

    # A key that is present but empty (``bty.target_disk_serial=``) is
    # as unusable as an absent one, so both count as missing.
    missing = [k for k in REQUIRED if not args.get(k, "").strip()]
    if missing:
        print(
            "bty-flash-on-boot: cmdline missing or empty required keys "
            f"({', '.join(missing)}); not flashing - dropping to console",
            file=sys.stderr,
            flush=True,
        )
        return 0  # not an error, just no work

    server = args["bty.server"].rstrip("/")
    mac = args["bty.mac"]
    image_url = args["bty.image_url"]
    target_serial = args["bty.target_disk_serial"]

    print(
        f"bty-flash-on-boot: server={server} mac={mac} "
        f"target_disk_serial={target_serial}",
        flush=True,
    )

    local_image = local_image_path(image_url)
    download(image_url, local_image)
    # Resolve the disk only after the download, right before flashing,
    # so the device path is as fresh as possible when ``bty flash`` runs.
    target = pick_target(target_serial)
    print(f"bty-flash-on-boot: target disk {target} (serial {target_serial})", flush=True)

    subprocess.run(
        ["bty", "flash", str(local_image), target, "--yes"],
        check=True,
    )

    signal_done(server, mac)

    print("bty-flash-on-boot: flash complete; rebooting in 5s", flush=True)
    subprocess.run(["sleep", "5"], check=False)
    # ``check=True`` so a failing ``systemctl reboot`` (no-systemd,
    # permission-denied, transition-blocked, ...) crashes the script
    # loudly instead of leaving the operator with a "flash complete;
    # rebooting in 5s" marker on the serial console followed by
    # nothing happening forever.
    subprocess.run(["systemctl", "reboot"], check=True)
    return 0


if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())
