#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.9"
# dependencies = ["sqlite-utils"]
# ///
"""
brew-search: fast offline-first search of Homebrew's formula index.

Usage:
  brew-search <query> [--refresh] [-f/--formulae] [-c/--casks] [--limit N] [--json]
  brew-search -C/--cache                     Show cache status
  brew-search --stale 1h --fresh 24h <query> Custom TTLs (bkt-style)

Cache TTL (bkt-style):
  --stale <dur>   Age after which a background refresh is triggered (default: 6h)
  --fresh <dur>   Age after which a synchronous refresh is forced (default: never)

Cache location: ~/.cache/brew-search/
  brew-search.db  — SQLite with FTS5 index for fast search
  formula.json    — raw API response for re-creation
  cask.json       — raw API response for re-creation
"""
from __future__ import annotations

import argparse
import json
import re
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen

import sqlite_utils

# ── constants ────────────────────────────────────────────────────────────────

CACHE_DIR   = Path.home() / ".cache" / "brew-search"  # all cache artifacts live here
DB_PATH     = CACHE_DIR / "brew-search.db"            # SQLite DB holding the FTS5 search index
FORMULA_URL = "https://formulae.brew.sh/api/formula.json"  # Homebrew public formula index
CASK_URL    = "https://formulae.brew.sh/api/cask.json"     # Homebrew public cask index
DEFAULT_STALE = 6 * 3600   # background refresh if cache older than 6 hours
TIMEOUT     = 10            # seconds (network timeout for index downloads)

# ── colour helpers ────────────────────────────────────────────────────────────

# Emit ANSI colour only when stdout is an interactive terminal, so piped or
# redirected output never contains escape codes.
USE_COLOR = sys.stdout.isatty()

def c(code: str, text: str) -> str:
    """Wrap *text* in the ANSI SGR escape for *code* when colour is enabled."""
    return f"\033[{code}m{text}\033[0m" if USE_COLOR else text

# Named style helpers (PEP 8: use `def` rather than lambdas bound to names).
def bold(t: str) -> str:
    return c("1", t)

def dim(t: str) -> str:
    return c("2", t)

def green(t: str) -> str:
    return c("32", t)

def yellow(t: str) -> str:
    return c("33", t)

def cyan(t: str) -> str:
    return c("36", t)

def red(t: str) -> str:
    return c("31", t)

# ── duration parsing ─────────────────────────────────────────────────────────

def parse_duration(s: str) -> int:
    """Parse a human duration string like '6h', '30m', '1d', '1h30m' into seconds.

    Accepts any sequence of <number><unit> tokens (units: s/m/h/d), or a bare
    integer meaning seconds.  Raises argparse.ArgumentTypeError on anything
    else, so it is usable directly as an argparse ``type=`` callable.
    """
    text = s.strip().lower()
    unit_seconds = {"s": 1, "m": 60, "h": 3600, "d": 86400}
    # Require the WHOLE string to consist of duration tokens.  The previous
    # findall-based parser silently ignored trailing garbage ('6hxx' -> 6h)
    # and rejected explicit zero durations like '0s' (sum was 0, so it fell
    # through to int('0s') and raised).
    if re.fullmatch(r"(?:\d+\s*[smhd]\s*)+", text):
        return sum(
            int(amount) * unit_seconds[unit]
            for amount, unit in re.findall(r"(\d+)\s*([smhd])", text)
        )
    # Fall back to a bare number of seconds ('90' -> 90).
    try:
        return int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid duration: {s!r}  (examples: 30m, 6h, 1d, 1h30m)"
        ) from None

def fmt_duration(seconds: float) -> str:
    """Render a duration in seconds as a compact string ('45s', '1h30m', '2d').

    At most the two most significant units are shown, and infinity renders
    as 'never'.
    """
    if seconds == float("inf"):
        return "never"
    total = int(seconds)
    if total < 60:
        return f"{total}s"
    if total < 3600:
        minutes, secs = divmod(total, 60)
        return f"{minutes}m{secs}s" if secs else f"{minutes}m"
    hours, minutes = divmod(total // 60, 60)
    if hours < 24:
        return f"{hours}h{minutes}m" if minutes else f"{hours}h"
    days, hours = divmod(hours, 24)
    return f"{days}d{hours}h" if hours else f"{days}d"

# ── database helpers ─────────────────────────────────────────────────────────

def get_db() -> sqlite_utils.Database:
    """Return a handle to the cache database, creating the cache dir if needed."""
    # mkdir is idempotent (exist_ok), so this is safe to call on every open.
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    db = sqlite_utils.Database(DB_PATH)
    return db

def table_age(db: sqlite_utils.Database, kind: str) -> float:
    """Seconds since *kind* was last imported; infinity when unknown."""
    if "_meta" not in db.table_names():
        return float("inf")
    try:
        updated = db["_meta"].get(kind)["updated_at"]
    except Exception:
        # Missing row or malformed metadata both mean "never imported".
        return float("inf")
    return time.time() - updated

def table_updated_at(db: sqlite_utils.Database, kind: str) -> float | None:
    """Epoch timestamp of the last import of *kind*, or None if unknown."""
    if "_meta" not in db.table_names():
        return None
    try:
        meta_row = db["_meta"].get(kind)
        return meta_row["updated_at"]
    except Exception:
        # No row / missing column: treat as never imported.
        return None

def table_count(db: sqlite_utils.Database, kind: str) -> int | None:
    """Row count recorded for *kind* at its last import, or None if unknown."""
    if "_meta" not in db.table_names():
        return None
    try:
        meta_row = db["_meta"].get(kind)
        return meta_row["count"]
    except Exception:
        # No row / missing column: treat as never imported.
        return None

def table_exists(db: sqlite_utils.Database, kind: str) -> bool:
    """True when the data table for *kind* has been created in the DB."""
    existing = db.table_names()
    return kind in existing

def json_path(kind: str) -> Path:
    """Path of the cached raw API JSON for *kind* (e.g. formula.json)."""
    filename = f"{kind}.json"
    return CACHE_DIR / filename

def save_raw_json(kind: str, data: list[dict]) -> None:
    """Keep the raw API JSON alongside the DB for easy re-creation.

    Writes to a sibling temp file first, then renames it over the target so
    an interrupted write never leaves a truncated JSON file behind.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    target = json_path(kind)
    tmp = target.with_suffix(".tmp")
    with tmp.open("w") as fh:
        json.dump(data, fh)
    tmp.replace(target)

def import_to_db(db: sqlite_utils.Database, kind: str, data: list[dict]) -> None:
    """Import fetched JSON into the SQLite DB with FTS index.

    Rebuilds the table for *kind* from scratch: flattens the searchable
    fields into columns (keeping the complete API record in the ``raw``
    JSON column), builds a porter-stemmed FTS5 index over the text columns,
    and records the import timestamp + row count in the ``_meta`` table.
    """
    if kind == "formula":
        rows = [
            {
                "name": item.get("name", ""),
                "desc": item.get("desc") or "",  # desc can be null in the API payload
                "homepage": item.get("homepage", ""),
                "version": (item.get("versions") or {}).get("stable", ""),
                "raw": json.dumps(item),  # full record kept for display / --json output
            }
            for item in data
        ]
        pk = "name"
    else:  # cask
        rows = [
            {
                "token": item.get("token", ""),
                "name": item.get("name") or "",  # NOTE(review): cask "name" may not be a plain string in the API — stored as-is; verify
                "desc": item.get("desc") or "",
                "homepage": item.get("homepage", ""),
                "version": str(item.get("version", "")),
                "raw": json.dumps(item),
            }
            for item in data
        ]
        pk = "token"

    # Drop old FTS first, then table, then re-create
    fts_name = f"{kind}_fts"
    if fts_name in db.table_names():
        db[fts_name].drop()
    if kind in db.table_names():
        db[kind].drop()

    db[kind].insert_all(rows, pk=pk)

    # Build FTS5 index (porter stemming, so e.g. "testing" also matches "test")
    if kind == "formula":
        db[kind].enable_fts(["name", "desc"], tokenize="porter", create_triggers=True)
    else:
        db[kind].enable_fts(["token", "name", "desc"], tokenize="porter", create_triggers=True)

    # Track when we last updated (read back by table_age / table_count)
    db["_meta"].insert(
        {"kind": kind, "updated_at": time.time(), "count": len(rows)},
        pk="kind",
        replace=True,
    )

# ── network ───────────────────────────────────────────────────────────────────

def fetch(url: str):
    """Download *url* and return the decoded JSON payload."""
    request = Request(url, headers={"User-Agent": "brew-search-cli/1.0"})
    with urlopen(request, timeout=TIMEOUT) as response:
        body = response.read()
    return json.loads(body)

def refresh(kind: str, url: str, silent: bool = False) -> bool:
    """Fetch the *kind* index from *url* and import it into the cache DB.

    Returns True on success, False on any failure.  Progress/error messages
    go to stderr unless *silent* is set (used by the background refresher).
    """
    if not silent:
        print(dim(f"  ↻ fetching {kind} index …"), file=sys.stderr)
    try:
        data = fetch(url)
        save_raw_json(kind, data)
        db = get_db()
        import_to_db(db, kind, data)
        if not silent:
            print(dim(f"  ✓ cached {len(data)} {kind}s"), file=sys.stderr)
        return True
    except Exception as e:
        # Deliberately broad: network, JSON-decode and SQLite errors all
        # degrade to "refresh failed" and the caller falls back to whatever
        # cache already exists.  (The old `except (URLError, Exception)` was
        # redundant — Exception already subsumes URLError.)
        if not silent:
            print(red(f"  ✗ fetch failed: {e}"), file=sys.stderr)
        return False

def background_refresh(kind: str, url: str) -> None:
    """Spawn a detached child process that silently refreshes *kind*'s cache."""
    cmd = [sys.executable, __file__, "--_bg-refresh", kind, url]
    try:
        subprocess.Popen(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,  # detach from our session so it outlives us
        )
    except Exception:
        # Best-effort: a failed spawn just means results stay slightly stale.
        pass

# ── cache status ──────────────────────────────────────────────────────────────

def show_cache_status() -> None:
    """Print a human-readable summary of the on-disk cache (for -C/--cache)."""
    have_db = DB_PATH.exists()
    print(f"  {bold('cache dir')}   {CACHE_DIR}")
    print(f"  {bold('database')}    {DB_PATH}  {'exists' if have_db else red('missing')}")
    if have_db:
        db_mb = DB_PATH.stat().st_size / (1024 * 1024)
        print(f"  {bold('db size')}     {db_mb:.1f} MB")
    print()

    db = get_db() if have_db else None

    for kind in ("formula", "cask"):
        label = green("formula") if kind == "formula" else yellow("cask")
        print(f"  {bold(label)}")

        # Raw JSON cache file next to the DB.
        raw = json_path(kind)
        if raw.exists():
            st = raw.stat()
            raw_mb = st.st_size / (1024 * 1024)
            stamp = datetime.fromtimestamp(st.st_mtime)
            raw_age = time.time() - st.st_mtime
            print(f"    json      {raw.name}  {raw_mb:.1f} MB  {dim(stamp.strftime('%Y-%m-%d %H:%M'))}  ({fmt_duration(raw_age)} ago)")
        else:
            print(f"    json      {dim('not cached')}")

        # Imported SQLite table + its FTS5 index.
        if db and table_exists(db, kind):
            entries = table_count(db, kind)
            when = table_updated_at(db, kind)
            age = table_age(db, kind)
            when_str = datetime.fromtimestamp(when).strftime('%Y-%m-%d %H:%M') if when else "?"
            entries_str = f"{entries} entries" if entries else "?"
            print(f"    db index  {entries_str}  {dim(when_str)}  ({fmt_duration(age)} ago)")

            has_fts = f"{kind}_fts" in db.table_names()
            print(f"    fts5      {'ready' if has_fts else red('missing')}")
        else:
            print(f"    db index  {dim('not built')}")
        print()

# ── search ────────────────────────────────────────────────────────────────────

def score(name: str, desc: str, terms: list[str]) -> int:
    """Relevance score against the search terms (higher = better, 0 = no match).

    Every term must match somewhere; per term: exact name match scores 100,
    name prefix 60, name substring 30, description substring 10.
    """
    lname = name.lower()
    ldesc = desc.lower()
    result = 0
    for raw_term in terms:
        term = raw_term.lower()
        if term == lname:
            points = 100
        elif lname.startswith(term):
            points = 60
        elif term in lname:
            points = 30
        elif term in ldesc:
            points = 10
        else:
            return 0  # a term that hits neither name nor desc disqualifies
        result += points
    return result

def fts_query(terms: list[str]) -> str:
    """Turn search terms into an FTS5 MATCH expression (ANDed prefix phrases)."""
    # Double any internal quotes so each term stays a valid quoted FTS5
    # string, then append * so every term matches as a prefix.
    return " AND ".join(
        '"{}"*'.format(term.replace('"', '""')) for term in terms
    )

def search(db: sqlite_utils.Database, kind: str, query: str, limit: int) -> list[dict]:
    """Search *kind* for *query*, returning up to *limit* raw API records, ranked.

    Candidates are retrieved via the FTS5 index when available (all terms
    ANDed, prefix-matched, capped at 200 rows); otherwise the whole table is
    scanned.  Candidates are then re-ranked in Python with score().
    """
    terms = query.split()
    if not terms:
        return []

    fts_table = f"{kind}_fts"

    # Try FTS5 first, fall back to full scan
    candidates = []
    if fts_table in db.table_names():
        fq = fts_query(terms)
        try:
            pk_col = "name" if kind == "formula" else "token"
            # Join the FTS table back to the base table on the pk to recover
            # the full `raw` JSON for each hit.
            sql = f"""
                SELECT {kind}.raw FROM {kind}
                JOIN {fts_table} ON {kind}.{pk_col} = {fts_table}.{pk_col}
                WHERE {fts_table} MATCH ?
                LIMIT 200
            """
            candidates = [json.loads(row[0]) for row in db.execute(sql, [fq]).fetchall()]
        except Exception:
            pass  # bad MATCH syntax or broken index: fall through to full scan

    if not candidates:
        # Fallback: load all and let Python score
        if kind in db.table_names():
            candidates = [json.loads(row[0]) for row in db.execute(f"SELECT raw FROM {kind}").fetchall()]

    # Score and rank
    scored = []
    for item in candidates:
        if kind == "formula":
            name = item.get("name", "")
        else:
            # For casks, match against both the token and the display name.
            name = f"{item.get('token', '')} {item.get('name', '')}"
        desc = item.get("desc") or ""
        s = score(name, desc, terms)
        if s > 0:
            scored.append((s, item))

    # Highest score first; ties broken alphabetically for stable output.
    scored.sort(key=lambda x: (-x[0], x[1].get("name") or x[1].get("token", "")))
    return [item for _, item in scored[:limit]]

# ── display ───────────────────────────────────────────────────────────────────

def fmt_formula(f: dict) -> str:
    """Render a formula record as a coloured multi-line display block.

    Layout: 'name  version', then description and homepage lines — each
    emitted only when non-empty.  (The old version printed a stray '  '
    line when desc was empty, the literal text 'None' when desc was null,
    and an ANSI-only line for a missing homepage when colour was on,
    because dim('') is still truthy.)
    """
    name = bold(green(f["name"]))
    version = dim(f.get("versions", {}).get("stable", ""))
    desc = f.get("desc") or ""
    homepage = f.get("homepage") or ""
    lines = [f"{name}  {version}"]
    if desc:
        lines.append(f"  {desc}")
    if homepage:
        lines.append(f"  {dim(homepage)}")
    return "\n".join(lines)

def fmt_cask(f: dict) -> str:
    """Render a cask record as a coloured multi-line display block.

    Same layout as fmt_formula: 'token  version', then description and
    homepage lines, each emitted only when non-empty (previously empty
    desc/homepage still produced filler lines).
    """
    name = bold(cyan(f.get("token", "")))
    version = dim(str(f.get("version", "")))
    desc = f.get("desc") or ""
    homepage = f.get("homepage") or ""
    lines = [f"{name}  {version}"]
    if desc:
        lines.append(f"  {desc}")
    if homepage:
        lines.append(f"  {dim(homepage)}")
    return "\n".join(lines)

def display_section(results: list, kind: str) -> None:
    """Print one result section (coloured header + formatted entries)."""
    if not results:
        return
    if kind == "cask":
        header, render = yellow("casks"), fmt_cask
    else:
        header, render = green("formulae"), fmt_formula
    print(f"  {header}")
    print()
    for entry in results:
        print(render(entry))
        print()

# ── main ──────────────────────────────────────────────────────────────────────

def ensure_cache(db: sqlite_utils.Database, kind: str, url: str,
                 force: bool, stale: int, fresh: int | None) -> bool:
    """Make sure searchable data for *kind* exists, refreshing per TTL policy.

    Refresh synchronously when forced, when the table is missing, or when
    the cache is older than *fresh*.  Otherwise kick off a background
    refresh when it is older than *stale*.  Returns True when data is
    available for searching.
    """
    must_fetch = force or not table_exists(db, kind)
    if not must_fetch and fresh is not None and table_age(db, kind) > fresh:
        must_fetch = True

    if must_fetch:
        fetched = refresh(kind, url, silent=False)
        # A failed fetch is still fine as long as an older table survives.
        return fetched or table_exists(get_db(), kind)

    if table_age(db, kind) > stale:
        background_refresh(kind, url)
    return True

def main(argv: list[str] | None = None) -> None:
    """CLI entry point.

    Modes, checked in order: hidden background-refresh worker (--_bg-refresh),
    cache-status display (-C), then the normal search path — ensure caches,
    search each selected kind, and render as JSON, grep-able lines, or the
    coloured human-readable listing.
    """
    ap = argparse.ArgumentParser(
        prog="brew-search",
        description="Fast offline-first Homebrew formula/cask search.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("query", nargs="?", help="Search query")
    ap.add_argument("--refresh",  action="store_true", help="Force synchronous re-fetch before searching")
    ap.add_argument("-f", "--formulae", "--formula", action="store_true", help="Search formulae only")
    ap.add_argument("-c", "--casks",    "--cask",    action="store_true", help="Search casks only")
    ap.add_argument("-C", "--cache",    action="store_true", help="Show cache status and exit")
    ap.add_argument("--stale",    type=parse_duration, default=None, metavar="DUR",
                    help="Background refresh if cache older than DUR (default: 6h). Examples: 30m, 6h, 1d")
    ap.add_argument("--fresh",    type=parse_duration, default=None, metavar="DUR",
                    help="Force synchronous refresh if cache older than DUR (default: never)")
    ap.add_argument("-n", "--limit", type=int, default=20, metavar="N", help="Max results (default 20)")
    ap.add_argument("--json",     action="store_true", help="Output raw JSON")
    ap.add_argument("-g", "--grep", action="store_true", help="Greppable output: slug version url\\ndescription")
    # hidden flag used by background subprocess; argparse exposes it as args._bg_refresh
    ap.add_argument("--_bg-refresh", nargs=2, metavar=("KIND", "URL"), help=argparse.SUPPRESS)

    args = ap.parse_args(argv)

    # ── background refresh mode (called by ourselves, silently) ──
    if getattr(args, "_bg_refresh", None):
        kind, url = args._bg_refresh
        refresh(kind, url, silent=True)
        return

    # ── cache status mode ──
    if args.cache:
        show_cache_status()
        return

    if not args.query:
        ap.print_help()
        sys.exit(0)

    stale = args.stale if args.stale is not None else DEFAULT_STALE
    fresh = args.fresh  # None means never force sync

    # Determine which kinds to search (both by default)
    if args.formulae and not args.casks:
        kinds = [("formula", FORMULA_URL)]
    elif args.casks and not args.formulae:
        kinds = [("cask", CASK_URL)]
    else:
        kinds = [("formula", FORMULA_URL), ("cask", CASK_URL)]

    # ── ensure caches exist ──
    db = get_db()
    for kind, url in kinds:
        if not ensure_cache(db, kind, url, args.refresh, stale, fresh):
            print(red(f"No cache for {kind} and fetch failed. Check your connection."), file=sys.stderr)
            sys.exit(1)

    # Re-open after potential refresh writes
    db = get_db()

    # ── search and display ──
    all_results = []
    for kind, url in kinds:
        age = table_age(db, kind)
        results = search(db, kind, args.query, args.limit)
        all_results.append((kind, results, age))

    if args.json:
        combined = {}
        for kind, results, _ in all_results:
            combined[kind] = results
        # A single-kind search emits a bare list rather than {kind: [...]}
        if len(all_results) == 1:
            combined = combined[all_results[0][0]]
        print(json.dumps(combined, indent=2))
        return

    if args.grep:
        # Tab-separated "slug<TAB>version<TAB>homepage" plus an indented
        # description line — easy to post-process with grep/awk.
        for kind, results, _ in all_results:
            for item in results:
                slug = item.get("token") or item.get("name", "")
                ver = str(item.get("version", "")) if kind == "cask" else (item.get("versions") or {}).get("stable", "")
                url = item.get("homepage", "")
                desc = item.get("desc") or ""
                print(f"{slug}\t{ver}\t{url}")
                print(f"  {desc}")
        return

    # Cache age header
    ages = [age for _, _, age in all_results]
    min_age = min(ages) if ages else 0
    age_str = (
        "just fetched" if min_age < 60
        else fmt_duration(min_age) + " old"
    )
    searching = " + ".join(
        yellow("casks") if k == "cask" else green("formulae")
        for k, _, _ in all_results
    )
    print(dim(f"  cache: {age_str}   searching {searching}"))
    print()

    total = 0
    first_name = None  # slug of the top result, used for the install hint below
    for kind, results, _ in all_results:
        display_section(results, kind)
        total += len(results)
        if results and first_name is None:
            r = results[0]
            first_name = r.get("token") or r.get("name", "")

    if total == 0:
        print(dim(f"  no results for {args.query!r}"))
    elif first_name:
        print(dim(f"  {total} result(s)  •  brew install {first_name}"))


# Script entry point (also re-invoked as a subprocess by background_refresh).
if __name__ == "__main__":
    main()
