"""Snakemake rules for the capture kit mixing benchmark.

Included by the root benchmarks/Snakefile.
"""

import json as _json
import os
import sys
import importlib.util as _ilu
from pathlib import Path

# Workflow start hook: make sure the shared log directory exists.
onstart:
    Path("logs/capture_kit").mkdir(parents=True, exist_ok=True)

# Base directory of the workflow; placed first on sys.path so that modules
# located next to the root Snakefile can be imported.
_bench = Path(workflow.basedir)
sys.path.insert(0, str(_bench))

# Directory holding the capture kit scripts and their config module.
CK_DIR = _bench / "capture_kit"

# Load capture_kit/config.py directly from its file path under the module
# name "capkit_config" (no package import involved).
_spec = _ilu.spec_from_file_location("capkit_config", CK_DIR / "config.py")
_ck = _ilu.module_from_spec(_spec)
_spec.loader.exec_module(_ck)

# Directories taken from the config module, normalised to Path objects.
CK_RESULTS = Path(_ck.RESULTS_DIR)
CK_FIGURES = Path(_ck.FIGURES_DIR)
CK_DB_DIR = Path(_ck.DB_DIR)
CK_MANIFEST_DIR = Path(_ck.MANIFEST_DIR)

# With the "smoke_test" config flag set, only the "balanced" scenario is run;
# otherwise every scenario key defined in the config module is used.
SCENARIOS = ["balanced"] if config.get("smoke_test") else list(_ck.SCENARIOS)

# "wgs" is always included in addition to the selected scenarios.
CK_SCENARIOS_ALL = ["wgs", *SCENARIOS]


checkpoint capture_kit_assign:
    """Read 1KG panel, subsample N samples, write assignments.json.

    Runs 01a_assign_samples.py from the capture kit directory and sends its
    stdout and stderr to the rule log. Declared as a checkpoint so that input
    functions can read assignments.json once it has been produced.
    """
    input:
        manifest=str(ONEKG_MANIFEST),
    output:
        assignments=str(CK_MANIFEST_DIR / "assignments.json"),
    log:
        "logs/capture_kit/assign_samples.log",
    threads: 1
    resources:
        mem_mb=4_000,
    shell:
        """
        set -euo pipefail
        LOG=$(realpath {log})
        mkdir -p "$(dirname "$LOG")"
        cd {CK_DIR}
        python 01a_assign_samples.py > "$LOG" 2>&1
        """


rule capture_kit_mask_vcf:
    """Mask one per-sample VCF to keep only variants within a kit's BED regions.

    Pipeline: decompress the input VCF, prepend "chr" to the first column of
    every non-header line, keep the records that overlap the kit BED file
    (header preserved), and re-compress the result with bgzip.

    NOTE(review): the unconditional "chr" prefix assumes the input VCFs use
    un-prefixed contig names while the BED files use "chr"-prefixed ones --
    confirm against the data.

    Wildcards:
        tech:   kit name; selects "<tech>.bed" in the masking BED directory.
        sample: sample ID; selects "<sample>.vcf.gz" in the per-sample VCF
                directory.
    """
    input:
        vcf=str(_ck.ONEKG_VCF_DIR / "{sample}.vcf.gz"),
        bed=lambda wc: str(_ck.MASKING_BED_DIR / f"{wc.tech}.bed"),
    output:
        vcf=str(_ck.MASKED_DIR / "{tech}" / "{sample}.vcf.gz"),
    log:
        "logs/capture_kit/mask/{tech}/{sample}.log",
    threads: 1
    resources:
        mem_mb=4_000,
    shell:
        """
        set -euo pipefail
        LOG=$(realpath {log}) && mkdir -p "$(dirname "$LOG")" "$(dirname "{output.vcf}")"
        # The pipeline runs in a subshell so that stderr of every stage is
        # written to the log, not only that of the final bgzip command.
        (
          zcat {input.vcf} \
            | awk 'BEGIN{{OFS="\\t"}} /^#/ {{print; next}} {{$1="chr"$1; print}}' \
            | bedtools intersect -a stdin -b {input.bed} -header -wa \
            | bgzip > {output.vcf}
        ) 2>"$LOG"
        """


def _ck_all_split_vcfs(wildcards):
    """Return the per-sample VCF paths for all samples in assignments.json.

    Input function for rules that depend on the capture_kit_assign
    checkpoint: the checkpoint output is looked up first, then the JSON file
    is parsed. Each entry of its "samples" list is unpacked as a pair, of
    which only the first element (the sample ID) is used.
    """
    checkpoint_output = checkpoints.capture_kit_assign.get(**wildcards).output
    assignments = _json.loads(Path(checkpoint_output.assignments).read_text())
    vcf_dir = _ck.ONEKG_VCF_DIR
    return [
        str(vcf_dir / f"{sample_id}.vcf.gz")
        for sample_id, _ in assignments["samples"]
    ]


def _ck_all_masked_vcfs(wildcards):
    """Return the masked VCF paths required by all scenarios.

    Input function for rules that depend on the capture_kit_assign
    checkpoint. The "scenarios" object of assignments.json maps each scenario
    to a {sample: tech} mapping; every distinct (tech, sample) combination
    across all scenarios yields one path
    "<MASKED_DIR>/<tech>/<sample>.vcf.gz".

    The pairs are de-duplicated with a set and then sorted, so the returned
    list has the same order on every invocation regardless of string hash
    randomisation.
    """
    asgn = checkpoints.capture_kit_assign.get(**wildcards).output.assignments
    data = _json.loads(Path(asgn).read_text())
    pairs = {
        (tech, sample)
        for sc in data["scenarios"].values()
        for sample, tech in sc.items()
    }
    return [
        str(_ck.MASKED_DIR / tech / f"{sample}.vcf.gz")
        for tech, sample in sorted(pairs)
    ]


rule capture_kit_write_manifests:
    """Write one manifest TSV per scenario after all masking is done.

    Scenarios are "wgs" plus every entry of SCENARIOS. The per-sample and
    masked VCF inputs are resolved from assignments.json by input functions,
    so this rule is only scheduled once those files exist. Runs
    01b_write_manifests.py from the capture kit directory with stdout and
    stderr sent to the rule log, then touches the ".prepare_done" marker.
    """
    input:
        assignments=str(CK_MANIFEST_DIR / "assignments.json"),
        split_vcfs=_ck_all_split_vcfs,
        masked_vcfs=_ck_all_masked_vcfs,
    output:
        manifests=expand(
            str(CK_MANIFEST_DIR / "manifest_{scenario}.tsv"),
            scenario=CK_SCENARIOS_ALL,
        ),
        done=touch(str(CK_MANIFEST_DIR / ".prepare_done")),
    log:
        "logs/capture_kit/write_manifests.log",
    threads: 1
    resources:
        mem_mb=1_000,
    shell:
        """
        set -euo pipefail
        LOG=$(realpath {log})
        mkdir -p "$(dirname "$LOG")"
        cd {CK_DIR}
        python 01b_write_manifests.py > "$LOG" 2>&1
        """


rule capture_kit_build_one:
    """Build one AFQuery database for a capture kit scenario.

    Runs 02_build_databases.py from the capture kit directory with
    "--scenario <scenario>", sending stdout and stderr to the rule log. The
    ".build_done" marker inside the scenario's database directory is touched
    when the command succeeds.

    Wildcards:
        scenario: scenario name, restricted to word characters.
    """
    wildcard_constraints:
        scenario=r"\w+",
    input:
        manifest=str(CK_MANIFEST_DIR / "manifest_{scenario}.tsv"),
    output:
        done=touch(str(CK_DB_DIR / "db_{scenario}" / ".build_done")),
    log:
        "logs/capture_kit/build_{scenario}.log",
    threads: 52
    resources:
        mem_mb=64_000,
    shell:
        """
        set -euo pipefail
        LOG=$(realpath {log})
        mkdir -p "$(dirname "$LOG")"
        cd {CK_DIR}
        python 02_build_databases.py --scenario {wildcards.scenario} > "$LOG" 2>&1
        """


rule capture_kit_build_all:
    """Wait for all capture kit DBs to be built.

    Aggregation rule without a command: it requires the ".build_done" marker
    of every scenario in CK_SCENARIOS_ALL and touches a single top-level
    ".build_done" marker in the database directory.
    """
    input:
        done=[
            str(CK_DB_DIR / f"db_{scenario}" / ".build_done")
            for scenario in CK_SCENARIOS_ALL
        ],
    output:
        done=touch(str(CK_DB_DIR / ".build_done")),
    log:
        "logs/capture_kit/build_all.log",


rule capture_kit_metrics:
    """Dump variants, compute NADI and AF error metrics.

    Runs 03_compute_metrics.py from the capture kit directory once all
    databases are built; stdout and stderr go to the rule log. The declared
    outputs are not passed on the command line, so the script is presumably
    configured to write them itself -- verify against the script.
    """
    input:
        rules.capture_kit_build_all.output.done,
    output:
        parquet=str(CK_RESULTS / "merged.parquet"),
        summary=str(CK_RESULTS / "nadi_summary.json"),
    log:
        "logs/capture_kit/compute_metrics.log",
    threads: 4
    resources:
        mem_mb=32_000,
    shell:
        """
        set -euo pipefail
        LOG=$(realpath {log})
        mkdir -p "$(dirname "$LOG")"
        cd {CK_DIR}
        python 03_compute_metrics.py > "$LOG" 2>&1
        """


rule capture_kit_acmg:
    """Apply ACMG thresholds and compute misclassification rates.

    Runs 04_classify_acmg.py from the capture kit directory after the merged
    metrics parquet exists; stdout and stderr go to the rule log.
    """
    input:
        parquet=rules.capture_kit_metrics.output.parquet,
    output:
        results=str(CK_RESULTS / "acmg_results.json"),
    log:
        "logs/capture_kit/classify_acmg.log",
    threads: 1
    resources:
        mem_mb=8_000,
    shell:
        """
        set -euo pipefail
        LOG=$(realpath {log})
        mkdir -p "$(dirname "$LOG")"
        cd {CK_DIR}
        python 04_classify_acmg.py > "$LOG" 2>&1
        """


rule capture_kit_plot:
    """Generate publication figures.

    Runs 05_plot_figures.py from the capture kit directory; stdout and stderr
    go to the rule log. Every figure name below is expected in both PDF and
    PNG form in the figures directory.
    """
    input:
        parquet=rules.capture_kit_metrics.output.parquet,
        acmg=rules.capture_kit_acmg.output.results,
    output:
        expand(
            str(CK_FIGURES / "{name}.{ext}"),
            name=[
                "fig_capkit_af_error_violin",
                "fig_capkit_error_by_coverage",
                "fig_capkit_scatter",
                "fig_capkit_acmg",
                "fig_capkit_toward_pathogenic",
                "fig_capkit_an_ratio",
            ],
            ext=["pdf", "png"],
        ),
    log:
        "logs/capture_kit/plot_figures.log",
    threads: 1
    resources:
        mem_mb=4_000,
    shell:
        """
        set -euo pipefail
        LOG=$(realpath {log})
        mkdir -p "$(dirname "$LOG")"
        cd {CK_DIR}
        python 05_plot_figures.py > "$LOG" 2>&1
        """


rule capture_kit_all:
    """Run the complete capture kit mixing benchmark.

    Aggregation target without a command: it depends on every output of
    capture_kit_plot and, once those exist, touches the ".capture_kit_done"
    marker in the results directory.
    """
    input:
        rules.capture_kit_plot.output,
    output:
        touch(str(CK_RESULTS / ".capture_kit_done")),
