#!/usr/bin/env python3
"""bty live env: flash-on-boot.

Reads ``bty.*`` parameters from ``/proc/cmdline``, fetches the
assigned image from the bty server, runs ``bty flash``, signals
completion, and reboots.

Driven by ``bty-flash-on-boot.service`` after ``network-online.target``.
If no ``bty.*`` parameters are present in /proc/cmdline (e.g. the
operator is booting the live env interactively for debugging), the
service exits 0 and the live env drops to its console.

Recognised cmdline parameters:
- ``bty.server=URL``              - base URL of the bty-web server
- ``bty.mac=MAC``                 - this machine's MAC (xx-xx-...)
- ``bty.image_url=URL``           - direct URL to the image to flash
- ``bty.target_disk_serial=...``  - serial number of the disk to
                                    flash. REQUIRED. Matched against
                                    ``lsblk -d -o SERIAL`` at boot
                                    time. If the serial is missing
                                    or doesn't match any present
                                    disk, the script bails rather
                                    than picking the wrong drive.
                                    Operator picks via /ui/machines/
                                    {mac}; the server bakes it into
                                    the iPXE flash chain via
                                    ``ipxe_flash.j2``.
- ``bty.mode=interactive``        - short-circuit; ``bty-tui-on-tty1.service``
                                    handles the interactive flash flow on
                                    tty1 instead. Set by the server's
                                    ``ipxe_tui.j2`` for ``boot_policy=tui``
                                    (default for unknown MACs).
"""

from __future__ import annotations

import json
import shlex
import subprocess
import sys
import urllib.parse
import urllib.request
from pathlib import Path

# Kernel command line of the running live env; cmdline_args() reads the
# ``bty.*`` parameters from here.
CMDLINE = Path("/proc/cmdline")
# Directory local_image_path() places the downloaded image in.
LOCAL_IMAGE_DIR = Path("/var/tmp")

# Cmdline keys main() requires before it attempts a flash; if any is
# missing it logs the gap and returns 0 without flashing.
REQUIRED = ("bty.server", "bty.mac", "bty.image_url", "bty.target_disk_serial")


def say(msg: str, *, err: bool = False) -> None:
    """Emit ``msg`` on stdout (or stderr when ``err`` is true) and
    echo it onto /dev/tty1.

    The stdout/stderr copy is flushed immediately; per the systemd
    unit it ends up in the journal and on /dev/console (usually the
    serial console when the cmdline has ``console=ttyS0``). The
    /dev/tty1 copy exists for operators who only have the local
    framebuffer: without it they would watch agetty's idle login
    prompt while the flash runs invisibly.
    """
    stream = sys.stdout
    if err:
        stream = sys.stderr
    print(msg, file=stream, flush=True)

    # The tty1 echo is strictly best-effort: a missing device, EIO
    # and the like must never abort the flash. The line printed
    # above remains the authoritative record.
    try:
        with open("/dev/tty1", "w") as console:
            print(msg, file=console)
    except OSError:
        return


# File-name endings local_image_path() accepts unchanged; it compares them
# against the lower-cased file name, so matching is case-insensitive.
_FORMAT_EXTENSIONS = (".img.zst", ".img.xz", ".img.gz", ".img.bz2", ".qcow2", ".img")


def local_image_path(image_url: str, image_format: str | None = None) -> Path:
    """Pick a local cache path for the image fetched from ``image_url``.

    ``bty.images.detect_format()`` keys off the file extension
    (.qcow2, .img, .img.{zst,xz,gz,bz2}) so the on-disk name has
    to keep it. The URL filename usually carries the extension --
    but bty-web's ``/images/<sha>/<display-name>`` route emits
    URLs whose trailing segment is the catalog entry's display
    name, which is human text and may have no extension at all
    (``nosi fedora-sysdev (x86_64, rolling)``).

    Resolution order:

      1. URL path filename (URL-decoded). If it already ends in
         a recognised extension, keep it as-is.
      2. URL path filename + ``.<image_format>`` suffix when
         ``image_format`` is supplied. Catches the
         human-named-catalog case. A leading dot on
         ``image_format`` is tolerated.
      3. ``bty-flash-on-boot.img`` when the URL yields no usable
         filename at all (empty path, ``.`` or ``..``).

    A name without a recognised extension and without
    ``image_format`` is returned unchanged -- no format is guessed.

    The returned path is always a direct child of
    ``LOCAL_IMAGE_DIR``: path separators and NUL bytes that appear
    after URL-decoding (``%2F``, ``%00``) are replaced with ``_``.
    """
    raw_name = Path(urllib.parse.urlparse(image_url).path).name
    decoded = urllib.parse.unquote(raw_name) if raw_name else ""

    # Percent-decoding can reintroduce "/" (or NUL). Joining such a
    # name onto LOCAL_IMAGE_DIR would nest the file in unexpected
    # sub-directories or -- for a leading "/" -- discard
    # LOCAL_IMAGE_DIR altogether, so flatten it to a single component.
    safe_name = decoded.replace("/", "_").replace("\x00", "_")
    if safe_name in ("", ".", ".."):
        safe_name = "bty-flash-on-boot.img"

    if safe_name.lower().endswith(_FORMAT_EXTENSIONS):
        return LOCAL_IMAGE_DIR / safe_name

    # Accept both "qcow2" and ".qcow2" without producing "name..qcow2".
    suffix = image_format.lstrip(".") if image_format else ""
    if suffix:
        return LOCAL_IMAGE_DIR / f"{safe_name}.{suffix}"
    return LOCAL_IMAGE_DIR / safe_name


def cmdline_args() -> dict[str, str]:
    """Parse /proc/cmdline into a dict of ``bty.*`` -> value tokens.

    Only tokens of the form ``bty.<name>=<value>`` are kept; the
    key is everything before the first ``=``. When a key appears
    more than once the last occurrence wins.

    Tokenisation uses :func:`shlex.split` so quoted values survive
    intact. If the cmdline has unbalanced quoting (for instance a
    lone apostrophe in an unrelated parameter), ``shlex`` raises
    ``ValueError``; in that case the text is split on plain
    whitespace instead so the service does not crash on a cmdline
    the kernel itself accepted.
    """
    raw = CMDLINE.read_text()
    try:
        tokens = shlex.split(raw)
    except ValueError:
        # Unbalanced quote somewhere on the cmdline -- degrade to
        # whitespace splitting rather than aborting the boot flow.
        tokens = raw.split()

    out: dict[str, str] = {}
    for token in tokens:
        if token.startswith("bty.") and "=" in token:
            k, _, v = token.partition("=")
            out[k] = v
    return out


def pick_target(target_serial: str) -> str:
    """Resolve the disk whose SERIAL matches ``target_serial``.

    Serial (vs /dev/sda) is the durable identifier baked into the
    iPXE flash chain by bty-web. ``/dev/sda`` can flip across
    kernel versions and udev rules but the disk's serial number
    is fixed. Matching on serial means a swapped-out or
    renumbered drive triggers a clean refusal here instead of
    wiping the wrong target.

    Shells out to ``lsblk -d -e7 -J`` (``-d`` strips partitions,
    ``-e7`` excludes loop devices). ``-o`` adds SERIAL so we can
    compare directly; lsblk's structured JSON is what bty's own
    ``list disks`` codepath consumes too.

    Both the requested serial and each disk's serial are compared
    with surrounding whitespace stripped.

    Returns the device path (lsblk's ``PATH`` column) of the first
    matching entry of type ``disk``.

    Raises :class:`SystemExit` with an operator-readable message
    when ``target_serial`` is empty/blank (lsblk is not consulted)
    or when no disk matches. The serial-list-on-failure helps an
    operator who can see tty1 (KVM / IPMI / serial console)
    figure out which disk to pick in /ui/machines/{mac}.
    Raises :class:`subprocess.CalledProcessError` if lsblk exits
    non-zero.
    """
    wanted = target_serial.strip() if isinstance(target_serial, str) else ""
    if not wanted:
        # A blank serial must never be allowed to "match" a disk
        # whose own serial is blank after stripping -- that would
        # select a drive without any real identification.
        raise SystemExit(
            "bty-flash-on-boot: bty.target_disk_serial is empty; "
            "refusing to pick a disk without a serial to match."
        )

    out = subprocess.run(
        ["lsblk", "-d", "-e7", "-J", "-o", "PATH,SERIAL,RM,RO,TYPE,MODEL,SIZE"],
        check=True,
        capture_output=True,
        text=True,
    )
    payload = json.loads(out.stdout)
    candidates = []
    for disk in payload.get("blockdevices", []):
        if disk.get("type") != "disk":
            continue
        path = disk.get("path")
        # Mirror bty.disks.list_disks's whitespace-strip on serial
        # so the inventory side (which records the stripped value)
        # and the flash side (this lookup) agree on the canonical
        # form. Some USB enclosures / vendor firmware ship
        # serials with trailing whitespace.
        raw_serial = disk.get("serial")
        serial = raw_serial.strip() if isinstance(raw_serial, str) else raw_serial
        if path:
            candidates.append((path, serial, disk.get("model"), disk.get("size")))
            if serial == wanted:
                return str(path)
    rendered = ", ".join(
        f"{p} (serial={s!r}, model={m!r}, size={sz!r})" for p, s, m, sz in candidates
    )
    raise SystemExit(
        f"bty-flash-on-boot: no disk on this host has "
        f"serial={target_serial!r}. Current disks: "
        f"{rendered or '(none)'}. The operator's pick in /ui/machines "
        "is stale; re-pick after the next inventory."
    )


def download(url: str, dest: Path) -> None:
    """Stream ``url`` into ``dest``.

    Atomic via a sibling ``.partial`` file + rename so a torn
    download (server hangup, OOM kill, kernel panic) can't leave
    a half-written file where ``bty flash --yes`` would read it
    as a complete image. ``timeout=300`` on the HTTP read keeps
    the boot from hanging indefinitely on a misconfigured server.

    When the response carries a parseable ``Content-Length``
    header, the number of bytes written is checked against it
    before the rename; a mismatch raises :class:`OSError`. This
    matters because a chunked ``read(n)`` on an HTTP response
    reports a prematurely closed connection as ordinary
    end-of-data rather than as an error.

    Any failure (including ``KeyboardInterrupt``) removes the
    ``.partial`` file and re-raises the original exception.
    """
    say(f"bty-flash-on-boot: downloading {url} -> {dest}")
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Append to the full name instead of manipulating the suffix so
    # unusual file names (no dot, trailing dot, ...) are handled too.
    tmp = dest.with_name(dest.name + ".partial")
    try:
        with urllib.request.urlopen(url, timeout=300) as resp, tmp.open("wb") as f:
            # Expected body size, if the server announced one.
            expected = None
            header = resp.headers.get("Content-Length")
            if header is not None:
                try:
                    expected = int(header)
                except ValueError:
                    expected = None  # malformed header: skip the check

            written = 0
            while True:
                chunk = resp.read(1 << 20)
                if not chunk:
                    break
                f.write(chunk)
                written += len(chunk)

        if expected is not None and written != expected:
            raise OSError(
                f"bty-flash-on-boot: incomplete download of {url}: "
                f"got {written} of {expected} bytes"
            )
        tmp.replace(dest)
    except BaseException:
        # Clean up the partial on any failure -- including
        # KeyboardInterrupt -- so a re-attempted boot starts fresh.
        # Cleanup is best-effort: an error while unlinking must not
        # replace the exception that actually caused the failure.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise


def signal_done(server: str, mac: str) -> None:
    """Tell the bty server that this machine finished flashing.

    Sends an empty-body POST to ``<server>/pxe/<mac>/done`` with a
    10 second timeout. Purely best-effort: any exception is logged
    to stderr via ``say`` and swallowed so the caller can carry on.
    """
    endpoint = "{}/pxe/{}/done".format(server.rstrip("/"), mac)
    try:
        request = urllib.request.Request(endpoint, data=b"", method="POST")
        with urllib.request.urlopen(request, timeout=10):
            say(f"bty-flash-on-boot: completion signalled to {endpoint}")
    except Exception as exc:  # noqa: BLE001
        failure = (
            f"bty-flash-on-boot: completion signal to {endpoint} failed: {exc}; "
            "continuing to reboot anyway"
        )
        say(failure, err=True)


def main() -> int:
    """Run the flash-on-boot flow and return the process exit code.

    Steps: parse the ``bty.*`` cmdline parameters, bail out early
    (exit code 0) for interactive mode or when a required key is
    missing/empty, otherwise download the image, resolve the target
    disk by serial, run ``bty flash``, signal completion to the
    server and reboot.

    The optional ``bty.image_format`` cmdline parameter is passed to
    ``local_image_path`` so URLs without a file extension still get
    a cache name with one.

    Failures in the download, disk lookup, ``bty flash`` or
    ``systemctl reboot`` steps propagate as exceptions.
    """
    args = cmdline_args()

    # Interactive mode (boot_policy=tui on the server side) -> defer
    # to bty-tui-on-tty1.service; that unit owns tty1 and runs
    # bty-tui --server URL --mac MAC for the operator. Trying to
    # flash here would race the TUI session. For the same reason
    # this notice goes to stdout only (no tty1 mirror, no tty1
    # reset): tty1 is not ours to draw on in this mode.
    if args.get("bty.mode") == "interactive":
        print(
            "bty-flash-on-boot: bty.mode=interactive on cmdline; "
            "bty-tui-on-tty1.service handles this boot",
            flush=True,
        )
        return 0

    # A key that is present but empty (``bty.target_disk_serial=``)
    # is as unusable as an absent one, so treat both as missing.
    missing = [k for k in REQUIRED if not args.get(k)]
    if missing:
        say(
            "bty-flash-on-boot: cmdline missing required keys "
            f"({', '.join(missing)}); not flashing - dropping to console",
            err=True,
        )
        return 0  # not an error, just no work

    # Clear tty1 once, now that we know we are really flashing, so
    # our progress messages land on a blank canvas instead of
    # stacking on top of whatever agetty already painted. ``\033c``
    # is a full terminal reset (works on both VT and serial);
    # rendering it on /dev/tty1 directly so the systemd journal
    # isn't polluted with escape sequences.
    try:
        with open("/dev/tty1", "w") as tty:
            tty.write("\033c\n")
            tty.write("  bty live env -- network flash mode\n\n")
    except OSError:
        pass

    server = args["bty.server"].rstrip("/")
    mac = args["bty.mac"]
    image_url = args["bty.image_url"]
    target_serial = args["bty.target_disk_serial"]

    say(
        f"bty-flash-on-boot: server={server} mac={mac} "
        f"target_disk_serial={target_serial}"
    )

    # Forward the optional format hint; without it a URL whose last
    # segment is a human-readable display name yields a cache file
    # with no extension.
    local_image = local_image_path(image_url, args.get("bty.image_format"))
    download(image_url, local_image)
    target = pick_target(target_serial)
    say(f"bty-flash-on-boot: target disk {target} (serial {target_serial})")

    say("bty-flash-on-boot: flashing now (this is the long quiet part) ...")
    subprocess.run(
        ["bty", "flash", str(local_image), target, "--yes"],
        check=True,
    )

    signal_done(server, mac)

    say("bty-flash-on-boot: flash complete; rebooting in 5s")
    subprocess.run(["sleep", "5"], check=False)
    # ``check=True`` so a failing ``systemctl reboot`` (no-systemd,
    # permission-denied, transition-blocked, ...) crashes the script
    # loudly instead of leaving the operator with a "flash complete;
    # rebooting in 5s" marker on the serial console followed by
    # nothing happening forever.
    subprocess.run(["systemctl", "reboot"], check=True)
    return 0


# Script entry point: run the flow and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
