#!/usr/bin/env python3
"""Benchmark Treestamps cold-load, set, get, dumpf, and WAL replay."""

from __future__ import annotations

import argparse
import shutil
import statistics
import sys
import tempfile
import time
from pathlib import Path

from treestamps.grove import Grovestamps, GrovestampsConfig

# Program name handed to GrovestampsConfig; the benchmark also uses it to find
# the stamp files (".{PROGRAM_NAME}_treestamps*") it deletes from tree copies.
PROGRAM_NAME = "bench"


def build_tree(root: Path, dirs: int, files_per_dir: int, depth: int) -> list[Path]:
    """Create a synthetic tree and return the list of file paths.

    Each of the *dirs* leaf directories sits at the end of its own chain of
    *depth* nested directories under *root* (``d0_<i>/d1_<i>/.../d<depth-1>_<i>``)
    and holds *files_per_dir* empty files named ``file_<j>.bin``.

    Args:
        root: Directory under which the tree is created.
        dirs: Number of leaf directories.
        files_per_dir: Number of files touched in each leaf directory.
        depth: Nesting depth of every leaf directory; must be >= 1.

    Returns:
        All created file paths, in creation order (leaf by leaf).

    Raises:
        ValueError: If *depth* < 1. With depth 0 every "leaf" would be *root*
            itself, so the same files would be touched and listed *dirs* times.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    paths: list[Path] = []
    for d in range(dirs):
        sub = root
        for level in range(depth):
            # The per-leaf suffix keeps every chain disjoint from the others.
            sub = sub / f"d{level}_{d}"
        sub.mkdir(parents=True, exist_ok=True)
        for f in range(files_per_dir):
            p = sub / f"file_{f}.bin"
            p.touch()
            paths.append(p)
    return paths


def write_stamp_file(root: Path, paths: list[Path]) -> None:
    """Seed *root* with a persisted stamp file covering *paths*.

    Every path gets a distinct timestamp (now + its index), then the grove is
    dumped so later cold-load measurements have a stamp file to read.
    """
    grove_config = GrovestampsConfig(PROGRAM_NAME, paths=(root,))
    grove = Grovestamps(grove_config)
    stamps = grove[root]
    start = time.time()
    for offset, path in enumerate(paths):
        stamps.set(path, start + offset)
    grove.dumpf()


def time_block(fn, repeats: int = 3) -> tuple[float, float, float]:
    """Call *fn* *repeats* times and summarize the wall-clock durations.

    Each call is timed with ``time.perf_counter``.

    Returns:
        ``(fastest, median, slowest)`` in seconds.
    """
    durations = []
    for _ in range(repeats):
        started = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - started)
    fastest = min(durations)
    middle = statistics.median(durations)
    slowest = max(durations)
    return fastest, middle, slowest


def fmt_secs(s: float) -> str:
    """Format a duration in seconds as a fixed-width string in us, ms, or s."""
    # (exclusive upper bound, multiplier, format spec, unit), smallest first.
    for bound, scale, spec, unit in (
        (1e-3, 1e6, "8.1f", "us"),
        (1.0, 1e3, "8.2f", "ms"),
    ):
        if s < bound:
            return f"{s * scale:{spec}} {unit}"
    return f"{s:8.3f} s"


def fmt_rate(n: int, s: float) -> str:
    """Format *n* operations over *s* seconds as a fixed-width ops/s string.

    Non-positive durations are reported as ``inf`` instead of dividing.
    """
    return "       inf ops/s" if s <= 0 else f"{n / s:10.0f} ops/s"


def _print_min_med_max(stats: tuple[float, float, float]) -> None:
    """Print a ``time_block`` ``(min, median, max)`` result as one report line."""
    lo, med, hi = stats
    print(f"  min/med/max: {fmt_secs(lo)} / {fmt_secs(med)} / {fmt_secs(hi)}")


def _fresh_copy(src: Path, dst: Path) -> list[Path]:
    """Copy tree *src* to *dst* and drop the copy's top-level stamp files.

    Entries matching ``.{PROGRAM_NAME}_treestamps*`` directly under *dst* are
    unlinked after the copy; matches in subdirectories, if any, are left alone.

    Returns:
        Every ``file_*.bin`` under *dst*, sorted.
    """
    shutil.copytree(src, dst)
    for stale in dst.glob(f".{PROGRAM_NAME}_treestamps*"):
        stale.unlink()
    return sorted(dst.rglob("file_*.bin"))


def bench(args: argparse.Namespace) -> None:
    """Run every benchmark phase against a synthetic tree in a temp directory.

    Phases, in order:

    1. Build the tree.
    2. Cold-load a seeded stamp file, with and without ignore globs.
    3. Time ``set()``, ``get()`` and ``dumpf()`` on a fresh copy.
    4. Cold-load a second fresh copy whose sets were never dumped.

    The temp directory is removed afterwards, even if a phase raises.

    Args:
        args: Parsed options; reads ``dirs``, ``files``, ``depth`` and
            ``repeats``.
    """
    work = Path(tempfile.mkdtemp(prefix="treestamps-bench-"))
    try:
        root = work / "tree"
        root.mkdir()

        n_paths = args.dirs * args.files
        print(
            f"Building tree: {args.dirs} dirs x {args.files} files x depth"
            f" {args.depth} = {n_paths} files at {root}"
        )
        t0 = time.perf_counter()
        paths = build_tree(root, args.dirs, args.files, args.depth)
        print(f"  build:       {fmt_secs(time.perf_counter() - t0)}")

        # --- Cold load -------------------------------------------------------
        print("\nSeeding stamp file ...")
        write_stamp_file(root, paths)

        print(
            f"\n[cold load] Grovestamps() over {n_paths}-entry stamp file"
            f"  ({args.repeats} runs)"
        )

        def cold_load() -> None:
            Grovestamps(GrovestampsConfig(PROGRAM_NAME, paths=(root,)))

        _print_min_med_max(time_block(cold_load, args.repeats))

        # --- Cold load with ignore globs ------------------------------------
        # Stresses _is_path_skipped, which runs once per directory entry.
        ignore_globs = ("*.tmp", "__pycache__", "*.pyc", ".git", "node_modules")
        print(
            f"\n[cold load + ignore] {len(ignore_globs)} ignore globs"
            f"  ({args.repeats} runs)"
        )

        def cold_load_ignore() -> None:
            Grovestamps(
                GrovestampsConfig(
                    PROGRAM_NAME,
                    paths=(root,),
                    ignore=ignore_globs,
                    check_config=False,
                )
            )

        _print_min_med_max(time_block(cold_load_ignore, args.repeats))

        # --- set() throughput ------------------------------------------------
        # Fresh tree under root2 to isolate from cold-load measurements.
        root2 = work / "tree2"
        paths2 = _fresh_copy(root, root2)
        # Report the number of operations actually performed, not the nominal
        # dirs * files, so the ops/s figures match the loops below.
        n_ops = len(paths2)

        print(f"\n[set]      {n_ops} sets")
        gs2 = Grovestamps(GrovestampsConfig(PROGRAM_NAME, paths=(root2,)))
        ts2 = gs2[root2]
        base = time.time()
        t0 = time.perf_counter()
        for i, p in enumerate(paths2):
            ts2.set(p, base + i)
        elapsed = time.perf_counter() - t0
        print(f"  total:       {fmt_secs(elapsed)}   {fmt_rate(n_ops, elapsed)}")

        # --- get() throughput (after sets) -----------------------------------
        print(f"\n[get]      {n_ops} gets")
        t0 = time.perf_counter()
        for p in paths2:
            ts2.get(p)
        elapsed = time.perf_counter() - t0
        print(f"  total:       {fmt_secs(elapsed)}   {fmt_rate(n_ops, elapsed)}")

        # --- dumpf() ---------------------------------------------------------
        print(f"\n[dumpf]    after {n_ops} sets")
        t0 = time.perf_counter()
        ts2.dumpf()
        elapsed = time.perf_counter() - t0
        print(f"  total:       {fmt_secs(elapsed)}")

        # --- WAL replay ------------------------------------------------------
        # Fresh tree, do N sets, do NOT dumpf, then time cold load.
        root3 = work / "tree3"
        paths3 = _fresh_copy(root, root3)
        n_wal = len(paths3)

        gs3 = Grovestamps(GrovestampsConfig(PROGRAM_NAME, paths=(root3,)))
        ts3 = gs3[root3]
        base = time.time()
        for i, p in enumerate(paths3):
            ts3.set(p, base + i)
        # _wal is a private Treestamps attribute; flush it so pending WAL
        # writes reach the file before the timed reloads read it back.
        if ts3._wal:
            ts3._wal.flush()
        # Clear refs so the WAL file can be reopened cleanly
        del ts3, gs3

        print(f"\n[wal load] cold Grovestamps() reading WAL with {n_wal} entries")

        def wal_load() -> None:
            Grovestamps(GrovestampsConfig(PROGRAM_NAME, paths=(root3,)))

        _print_min_med_max(time_block(wal_load, args.repeats))

    finally:
        shutil.rmtree(work, ignore_errors=True)


def main() -> int:
    """Parse and validate the CLI options, then run the benchmark.

    Returns:
        0 on success. Invalid options make argparse print usage and exit
        with status 2 before any tree is built.
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--dirs", type=int, default=100, help="number of leaf directories")
    p.add_argument(
        "--files", type=int, default=1000, help="files per leaf directory"
    )
    p.add_argument(
        "--depth", type=int, default=3, help="directory nesting depth (>=1)"
    )
    p.add_argument(
        "--repeats", type=int, default=3, help="repeats for cold-load measurements"
    )
    args = p.parse_args()
    # Fail fast instead of producing a bogus run:
    # - depth 0 collapses every leaf directory onto the root;
    # - negative counts print negative file totals;
    # - repeats 0 leaves time_block with no samples (min() of an empty list),
    #   which would only surface after the tree was built and seeded.
    for name in ("dirs", "files", "depth", "repeats"):
        value = getattr(args, name)
        if value < 1:
            p.error(f"--{name} must be >= 1 (got {value})")
    bench(args)
    return 0


if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    exit_status = main()
    sys.exit(exit_status)
