"""Snakemake rules for the performance benchmark.

Included by the root benchmarks/Snakefile.
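Output paths are resolved to absolute paths; relative paths (e.g. log files)
are interpreted against Snakemake's working directory, assumed to be
benchmarks/ (the root Snakefile's directory).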
"""

import os
import sys
from pathlib import Path

# Create the log directory while this Snakefile is parsed, i.e. before any rule
# runs. A plain statement is used instead of an `onstart:` handler because
# Snakemake registers a single onstart handler per workflow, so one defined in
# this included file would replace (or be replaced by) a handler defined in the
# root Snakefile.
os.makedirs("logs/performance", exist_ok=True)

# ---------------------------------------------------------------------------
# Configuration loading
# ---------------------------------------------------------------------------
# workflow.basedir is the directory of the main Snakefile (expected to be
# benchmarks/); the shared package and the performance scripts live below it.
_bench = Path(workflow.basedir)

# Directory holding the performance scripts and their config.py. Every shell
# command below cd's here before running a script.
PERF_DIR = _bench / "performance"

# Put benchmarks/ on sys.path so the shared package can be imported here.
sys.path.insert(0, str(_bench))
from shared.config import SEED  # noqa: E402

# Also expose the performance directory on sys.path, then execute its
# config.py as a standalone module named "perf_config".
sys.path.insert(0, str(PERF_DIR))
import importlib.util as _ilu  # noqa: E402

_spec = _ilu.spec_from_file_location("perf_config", PERF_DIR / "config.py")
_pc = _ilu.module_from_spec(_spec)
_spec.loader.exec_module(_pc)

# Absolute result/figure directories; Path.resolve() interprets a relative
# value against the current working directory.
# NOTE(review): the scripts run after `cd PERF_DIR`. If RESULTS_DIR or
# FIGURES_DIR is relative in config.py, a script that reads it from config
# (e.g. 06_plot.py, which gets no paths on its command line) may resolve it
# to a different location than this file does — confirm config.py uses
# absolute paths.
PERF_RESULTS, PERF_FIGURES = (
    Path(_dir).resolve() for _dir in (_pc.RESULTS_DIR, _pc.FIGURES_DIR)
)

# ---------------------------------------------------------------------------
# Experiment parameters (read from config, overridable via smoke_test)
# ---------------------------------------------------------------------------
# Evaluate the smoke-test flag once; it selects both the reduced parameter
# space and the raw-output directory below.
_SMOKE_TEST = bool(config.get("smoke_test"))

ONEKG_SUBSETS           = list(_pc.ONEKG_SUBSETS)
SYNTH_SCALES            = list(_pc.SYNTH_SCALES)
BUILD_SCALES            = list(_pc.BUILD_SCALES)
BUILD_THREAD_COUNTS     = list(_pc.BUILD_THREAD_COUNTS)
ANNOTATE_VARIANT_COUNTS = list(_pc.ANNOTATE_VARIANT_COUNTS)
ANNOTATE_THREAD_COUNTS  = list(_pc.ANNOTATE_THREAD_COUNTS)
BUILD_REPS              = _pc.BUILD_REPS
ANNOTATE_REPS           = _pc.ANNOTATE_REPS

# Smoke test: shrink parameter space for quick validation
if _SMOKE_TEST:
    ONEKG_SUBSETS           = [500]
    SYNTH_SCALES            = [1_000]
    BUILD_SCALES            = [1_000]
    BUILD_THREAD_COUNTS     = [1, 4]
    ANNOTATE_VARIANT_COUNTS = [10_000]
    ANNOTATE_THREAD_COUNTS  = [1, 4]
    BUILD_REPS              = 1
    ANNOTATE_REPS           = 1

# Fail at parse time with a clear message. Several rules call
# max(ONEKG_SUBSETS), which would otherwise fail with an opaque error. A
# repetition count below 1 would silently hand the collect rules an empty
# input list.
if not ONEKG_SUBSETS:
    raise ValueError("performance config: ONEKG_SUBSETS must not be empty")
for _reps_name, _reps in (("BUILD_REPS", BUILD_REPS), ("ANNOTATE_REPS", ANNOTATE_REPS)):
    if _reps < 1:
        raise ValueError(
            f"performance config: {_reps_name} must be >= 1, got {_reps!r}"
        )

# Repetition indices are 1-based: 1..N inclusive.
BUILD_REPS_LIST    = list(range(1, BUILD_REPS + 1))
ANNOTATE_REPS_LIST = list(range(1, ANNOTATE_REPS + 1))

# Intermediate per-run outputs live under PERF_RESULTS/raw/
# (or PERF_RESULTS/raw_smoke/ in smoke mode).
# NOTE(review): only this raw directory differs in smoke mode. The collected
# JSONs, concordance.json and the figures use the same paths as a full run, so
# a smoke run overwrites full-run results — confirm this is intended.
PERF_RAW = PERF_RESULTS / ("raw_smoke" if _SMOKE_TEST else "raw")

# ---------------------------------------------------------------------------
# Group A: Data preparation — one DB build per ONEKG_SUBSETS and SYNTH_SCALES entry
# ---------------------------------------------------------------------------


rule performance_prepare_1kg:
    """Build one 1KG database at a given subset size.

    Runs 01_prepare_data.py in 1kg mode for subset size ``n`` and writes its
    JSON summary; the script's stdout/stderr go to the rule's log file.
    """
    wildcard_constraints:
        n=r"\d+",
    input:
        # NOTE(review): ONEKG_MANIFEST is not defined in this file; it is
        # presumably provided by the root Snakefile. It is only a dependency
        # here and is not passed to the script.
        manifest=str(ONEKG_MANIFEST),
    output:
        done=str(PERF_RAW / "prepare" / "1kg_{n}.json"),
    log:
        "logs/performance/prepare_1kg_{n}.log",
    # Independent jobs per DB size run in parallel on SLURM;
    # threads is per-job maximum for a single node.
    # NOTE(review): the reserved thread count is not forwarded to
    # 01_prepare_data.py. Confirm the script sizes its own parallelism to match.
    threads: 52
    resources:
        mem_mb=32_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python 01_prepare_data.py \
            --mode 1kg --n-samples {wildcards.n} \
            --output {output.done:q} \
            > "$LOG" 2>&1
        """


rule performance_prepare_synth:
    """Build one synthetic database at a given sample count.

    Runs 01_prepare_data.py in synth mode for ``n`` samples and writes its
    JSON summary; the script's stdout/stderr go to the rule's log file.
    """
    wildcard_constraints:
        n=r"\d+",
    output:
        done=str(PERF_RAW / "prepare" / "synth_{n}.json"),
    log:
        "logs/performance/prepare_synth_{n}.log",
    # Independent jobs per DB size run in parallel on SLURM;
    # threads is per-job maximum for a single node.
    # NOTE(review): the reserved thread count is not forwarded to
    # 01_prepare_data.py. Confirm the script sizes its own parallelism to match.
    threads: 52
    resources:
        mem_mb=32_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python 01_prepare_data.py \
            --mode synth --n-samples {wildcards.n} \
            --output {output.done:q} \
            > "$LOG" 2>&1
        """


rule performance_data_inventory:
    """Collect all per-DB build results into data_inventory.json.

    Waits for every 1KG-subset and synthetic-scale prepare JSON and hands
    both lists to collect_prepare.py.
    """
    input:
        onekg=expand(
            str(PERF_RAW / "prepare" / "1kg_{n}.json"),
            n=ONEKG_SUBSETS,
        ),
        synth=expand(
            str(PERF_RAW / "prepare" / "synth_{n}.json"),
            n=SYNTH_SCALES,
        ),
    output:
        inventory=str(PERF_RESULTS / "data_inventory.json"),
    log:
        "logs/performance/data_inventory.log",
    threads: 1
    resources:
        mem_mb=4_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python collect_prepare.py \
            --onekg-jsons {input.onekg:q} \
            --synth-jsons {input.synth:q} \
            --output {output.inventory:q} \
            > "$LOG" 2>&1
        """


# ---------------------------------------------------------------------------
# Group B: Query scaling — one run per SYNTH_SCALES entry, then a merge step
# ---------------------------------------------------------------------------


rule performance_query_scaling_one:
    """Query scaling benchmark for one synthetic DB scale.

    Requires the synthetic DB of size ``n`` to be prepared first, then runs
    02_query_scaling.py against it and writes a per-scale JSON.
    """
    wildcard_constraints:
        n=r"\d+",
    input:
        done=str(PERF_RAW / "prepare" / "synth_{n}.json"),
    output:
        result=str(PERF_RAW / "query" / "query_{n}.json"),
    log:
        "logs/performance/query_{n}.log",
    threads: 1
    resources:
        mem_mb=16_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python 02_query_scaling.py \
            --scale {wildcards.n} \
            --output {output.result:q} \
            > "$LOG" 2>&1
        """


rule performance_query_scaling_collect:
    """Merge per-scale query results into query_scaling.json.

    Gathers one query JSON per SYNTH_SCALES entry and merges them with
    collect_query_scaling.py.
    """
    input:
        results=expand(
            str(PERF_RAW / "query" / "query_{n}.json"),
            n=SYNTH_SCALES,
        ),
    output:
        results=str(PERF_RESULTS / "query_scaling.json"),
    log:
        "logs/performance/query_scaling_collect.log",
    threads: 1
    resources:
        mem_mb=4_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python collect_query_scaling.py \
            --inputs {input.results:q} \
            --output {output.results:q} \
            > "$LOG" 2>&1
        """


# ---------------------------------------------------------------------------
# Group C: Build performance — one synthetic-VCF generation per BUILD_SCALES
# entry, then one build per (scale, thread count, repetition)
# ---------------------------------------------------------------------------


rule performance_build_synth_gen:
    """Generate synthetic VCFs for one build-bench scale.

    Runs 03_build.py in its ``gen`` phase for ``n`` samples. The JSON it
    writes is the dependency that the build-phase jobs wait on.
    """
    wildcard_constraints:
        n=r"\d+",
    output:
        done=str(PERF_RAW / "build" / "gen_{n}.json"),
    log:
        "logs/performance/build_gen_{n}.log",
    threads: 1
    resources:
        mem_mb=8_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python 03_build.py \
            --phase gen \
            --n-samples {wildcards.n} \
            --output {output.done:q} \
            > "$LOG" 2>&1
        """


rule performance_build_one:
    """Run one (n_samples, threads, rep) build benchmark combination.

    Runs 03_build.py in its ``build`` phase with ``t`` threads on the VCFs
    generated for scale ``n``. The repetition index ``r`` only distinguishes
    the output/log file names; it is not passed to the script.
    """
    wildcard_constraints:
        n=r"\d+",
        t=r"\d+",
        r=r"\d+",
    input:
        gen=str(PERF_RAW / "build" / "gen_{n}.json"),
    output:
        result=str(PERF_RAW / "build" / "build_{n}_{t}t_r{r}.json"),
    log:
        "logs/performance/build_{n}_{t}t_r{r}.log",
    # NOTE(review): Snakemake may scale `threads` down to the available cores.
    # The script is always told --threads from the wildcard, and mem_mb follows
    # the (possibly reduced) scheduler value. Confirm this is acceptable.
    threads: lambda wc: int(wc.t)
    resources:
        mem_mb=lambda wc, threads: threads * 4_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python 03_build.py \
            --phase build \
            --n-samples {wildcards.n} \
            --threads {wildcards.t} \
            --output {output.result:q} \
            > "$LOG" 2>&1
        """


rule performance_build_collect:
    """Merge all build results into build_perf.json.

    Gathers one JSON per (BUILD_SCALES x BUILD_THREAD_COUNTS x repetition)
    combination and merges them with collect_build_perf.py.
    """
    input:
        results=expand(
            str(PERF_RAW / "build" / "build_{n}_{t}t_r{r}.json"),
            n=BUILD_SCALES,
            t=BUILD_THREAD_COUNTS,
            r=BUILD_REPS_LIST,
        ),
    output:
        results=str(PERF_RESULTS / "build_perf.json"),
    log:
        "logs/performance/build_collect.log",
    threads: 1
    resources:
        mem_mb=4_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python collect_build_perf.py \
            --inputs {input.results:q} \
            --output {output.results:q} \
            > "$LOG" 2>&1
        """


# ---------------------------------------------------------------------------
# Group D: Annotation throughput — one run per
# (variant count, thread count, repetition)
# ---------------------------------------------------------------------------


rule performance_annotate_one:
    """Run one (n_variants, threads, rep) annotation benchmark.

    Every combination depends on the largest 1KG subset's prepare JSON and
    passes that subset size to 04_annotate.py. The repetition index ``r``
    only distinguishes the output/log file names; it is not passed to the
    script.
    """
    wildcard_constraints:
        n=r"\d+",
        t=r"\d+",
        r=r"\d+",
    input:
        done=str(PERF_RAW / "prepare" / "1kg_{max_subset}.json".format(
            max_subset=max(ONEKG_SUBSETS)
        )),
    output:
        result=str(PERF_RAW / "annotate" / "annotate_{n}_{t}t_r{r}.json"),
    log:
        "logs/performance/annotate_{n}_{t}t_r{r}.log",
    # NOTE(review): Snakemake may scale `threads` down to the available cores.
    # The script is always told --threads from the wildcard, and mem_mb follows
    # the (possibly reduced) scheduler value. Confirm this is acceptable.
    threads: lambda wc: int(wc.t)
    resources:
        mem_mb=lambda wc, threads: max(8_000, threads * 2_000),
    params:
        max_subset=max(ONEKG_SUBSETS),
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python 04_annotate.py \
            --n-variants {wildcards.n} \
            --threads {wildcards.t} \
            --max-subset {params.max_subset} \
            --output {output.result:q} \
            > "$LOG" 2>&1
        """


rule performance_annotate_collect:
    """Merge all annotation results and compute speedup ratios.

    Gathers one JSON per (ANNOTATE_VARIANT_COUNTS x ANNOTATE_THREAD_COUNTS x
    repetition) combination and hands them to collect_annotate.py.
    """
    input:
        results=expand(
            str(PERF_RAW / "annotate" / "annotate_{n}_{t}t_r{r}.json"),
            n=ANNOTATE_VARIANT_COUNTS,
            t=ANNOTATE_THREAD_COUNTS,
            r=ANNOTATE_REPS_LIST,
        ),
    output:
        results=str(PERF_RESULTS / "annotate_throughput.json"),
    log:
        "logs/performance/annotate_collect.log",
    threads: 1
    resources:
        mem_mb=4_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python collect_annotate.py \
            --inputs {input.results:q} \
            --output {output.results:q} \
            > "$LOG" 2>&1
        """


# ---------------------------------------------------------------------------
# Group E: bcftools comparison — one run per ONEKG_SUBSETS entry, plus one
# concordance job on the largest subset
# ---------------------------------------------------------------------------


rule performance_bcftools_one:
    """Run bcftools comparison for one 1KG subset size.

    Requires the 1KG DB of subset size ``n`` to be prepared first, then runs
    05_vs_bcftools.py and writes a per-subset JSON.
    """
    wildcard_constraints:
        n=r"\d+",
    input:
        done=str(PERF_RAW / "prepare" / "1kg_{n}.json"),
    output:
        result=str(PERF_RAW / "bcftools" / "bcftools_{n}.json"),
    log:
        "logs/performance/bcftools_{n}.log",
    threads: 1
    resources:
        mem_mb=16_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python 05_vs_bcftools.py \
            --n-samples {wildcards.n} \
            --output {output.result:q} \
            > "$LOG" 2>&1
        """


rule performance_bcftools_concordance:
    """Compute AF concordance on the largest 1KG subset.

    Runs 05_vs_bcftools.py with --concordance-only for max(ONEKG_SUBSETS).
    It writes concordance.json directly into the results directory, with no
    raw/ intermediate.
    """
    input:
        done=str(PERF_RAW / "prepare" / "1kg_{n}.json".format(
            n=max(ONEKG_SUBSETS)
        )),
    output:
        concordance=str(PERF_RESULTS / "concordance.json"),
    log:
        "logs/performance/bcftools_concordance.log",
    params:
        max_subset=max(ONEKG_SUBSETS),
    threads: 1
    resources:
        mem_mb=16_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python 05_vs_bcftools.py \
            --n-samples {params.max_subset} \
            --concordance-only \
            --output {output.concordance:q} \
            > "$LOG" 2>&1
        """


rule performance_bcftools_collect:
    """Merge per-subset bcftools results into bcftools_comparison.json.

    Gathers one bcftools JSON per ONEKG_SUBSETS entry and merges them with
    collect_bcftools.py.
    """
    input:
        results=expand(
            str(PERF_RAW / "bcftools" / "bcftools_{n}.json"),
            n=ONEKG_SUBSETS,
        ),
    output:
        comparison=str(PERF_RESULTS / "bcftools_comparison.json"),
    log:
        "logs/performance/bcftools_collect.log",
    threads: 1
    resources:
        mem_mb=4_000,
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python collect_bcftools.py \
            --inputs {input.results:q} \
            --output {output.comparison:q} \
            > "$LOG" 2>&1
        """


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------


rule performance_plot:
    """Generate publication figures from all result JSON files.

    Waits for every collected result JSON and then runs 06_plot.py, which
    must produce each listed figure in both PDF and PNG form.
    """
    input:
        str(PERF_RESULTS / "data_inventory.json"),
        str(PERF_RESULTS / "query_scaling.json"),
        str(PERF_RESULTS / "build_perf.json"),
        str(PERF_RESULTS / "annotate_throughput.json"),
        str(PERF_RESULTS / "bcftools_comparison.json"),
        str(PERF_RESULTS / "concordance.json"),
    output:
        expand(
            str(PERF_FIGURES / "{name}.{ext}"),
            name=[
                "fig1_query_scaling",
                "fig2_build_perf",
                "fig3_annotate_throughput",
                "fig4_bcftools_comparison",
                "fig5_concordance",
                "fig6_disk_footprint",
            ],
            ext=["pdf", "png"],
        ),
    log:
        "logs/performance/plot.log",
    threads: 1
    resources:
        mem_mb=4_000,
    # NOTE(review): 06_plot.py gets no input/output paths on its command line.
    # Presumably it reads the results and figure directories from config.py;
    # confirm they match PERF_RESULTS / PERF_FIGURES.
    # Create the log directory before `realpath` (which needs the parent to
    # exist). Keep the absolute log path so it survives the `cd` into PERF_DIR.
    shell:
        """
        set -euo pipefail
        mkdir -p "$(dirname {log:q})"
        LOG="$(realpath {log:q})"
        cd {PERF_DIR:q}
        python 06_plot.py > "$LOG" 2>&1
        """


rule performance_all:
    """Top-level target for the complete performance benchmark.

    Depends on every figure produced by performance_plot. Once they all
    exist, it touches an empty marker file in the results directory.
    """
    input:
        rules.performance_plot.output,
    output:
        touch(os.path.join(PERF_RESULTS, ".performance_done")),
