# GrapheonRL Demo Workflow -- Variant Calling Pipeline
#
# Mirrors the canonical Snakemake tutorial bioinformatics structure:
#   https://snakemake.readthedocs.io/en/stable/tutorial/tutorial.html
#
# Pipeline:
#   reference genome
#       |
#   [bwa_map x 6 samples]     ← fan-out: 6 parallel lanes, 4 cores each
#       |
#   [samtools_sort x 6]        ← 2 cores each
#       |
#   [samtools_index x 6]       ← 1 core each
#   [fastqc x 6]               ← 1 core each (independent QC branch)
#       |
#   joint_calling               ← synchronisation barrier: waits for ALL 6 samples
#       |
#   [filter_variants x 4 chroms] ← scatter: 4 parallel chromosome filters
#       |
#   gather_variants             ← gather
#       |
#   annotate
#       |
#   multiqc                     ← joins QC + annotation results
#
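# Total: 1 + 6*4 + 1 + 4 + 3 = 33 tasks
#   (download_ref + 4 per-sample rules x 6 samples + joint_calling
#    + 4 chromosome filters + gather/annotate/multiqc; the `all` target is extra)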
#
# Scheduling challenge:
#   - Pre-barrier: 24 tasks, must maximise parallelism under 8-core budget
#   - BWA uses 4 cores each: only 2 samples can map in parallel
#   - Critical path: ref → bwa_map → sort → index → joint → scatter → gather → annotate
#   - HEFT/GNNRL must schedule mapping before light QC tasks to avoid blocking
#     the barrier with core starvation

import time
import os

# Six sample identifiers, S1 through S6 (the tutorial names its samples A, B, ...).
SAMPLES = ["S" + str(n) for n in range(1, 7)]
# Chromosome labels used to fan out the per-chromosome filtering outputs.
CHROMOSOMES = ["chr" + str(n) for n in range(1, 5)]

onstart:
    # Create the four result sub-directories up front; makedirs also creates
    # the parent "results" directory and tolerates directories that exist.
    for subdir in ("mapped", "sorted", "qc", "filtered"):
        os.makedirs(f"results/{subdir}", exist_ok=True)
    # Persist the launch timestamp (seconds since the epoch) so the onsuccess
    # handler can compute the workflow wall time.
    with open("results/.start_time", "w") as stamp:
        launched = time.time()
        stamp.write(str(launched))

onsuccess:
    # Take the finish timestamp once and reuse it for both the marker file and
    # the report; str()/float() round-trips a Python float exactly, so reading
    # it back from disk would yield the same value.
    end = time.time()
    with open("results/.end_time", "w") as f:
        f.write(str(end))
    # Read the launch timestamp written by the onstart handler. The context
    # manager closes the handle deterministically instead of leaving it to
    # garbage collection.
    with open("results/.start_time") as f:
        start = float(f.read())
    # Summarise elapsed wall time together with the workflow dimensions.
    print(f"\n  Workflow wall time: {end - start:.1f}s  "
          f"({len(SAMPLES)} samples, {len(CHROMOSOMES)} chromosomes)")

rule all:
    # Aggregate target: it has no outputs of its own and only requests the
    # final report, whose dependencies cover every other rule in this file.
    input: "results/multiqc_report.txt"

# ── Reference genome (shared input) ──────────────────────────────────────

rule download_ref:
    # Stand-in for fetching the reference genome: sleeps one second, then
    # writes a two-line FASTA stub. The "\n" escape is expanded by Python, so
    # the shell receives a literal newline inside the single-quoted argument.
    output:
        "results/ref.fa"
    threads: 1
    resources:
        mem_mb=512
    shell:
        "sleep 1 && echo '>chr1\nACGT' > {output}"

# ── Per-sample: map → sort → index ───────────────────────────────────────
# These simulate: bwa mem (4 cores), samtools sort (2 cores), samtools index (1 core)

rule bwa_map:
    # Simulated read mapping for one sample: declares 4 threads and 4096 MB,
    # sleeps three seconds and writes a placeholder file tagged with the
    # sample name. Its only declared input is the shared reference.
    input: ref="results/ref.fa"
    output:
        "results/mapped/{sample}.bam"
    threads: 4
    resources:
        mem_mb=4096
    shell:
        "sleep 3 && echo 'mapped:{wildcards.sample}' > {output}"

rule samtools_sort:
    # Simulated sort of one sample's mapped file: declares 2 threads and
    # 2048 MB, sleeps two seconds and writes a placeholder tagged with the
    # sample name.
    input:
        "results/mapped/{sample}.bam"
    output:
        "results/sorted/{sample}.bam"
    threads: 2
    resources:
        mem_mb=2048
    shell:
        "sleep 2 && echo 'sorted:{wildcards.sample}' > {output}"

rule samtools_index:
    # Simulated indexing of one sample's sorted file: declares 1 thread and
    # 1024 MB, sleeps one second and writes a placeholder index next to the
    # sorted file.
    input:
        "results/sorted/{sample}.bam"
    output:
        "results/sorted/{sample}.bam.bai"
    threads: 1
    resources:
        mem_mb=1024
    shell:
        "sleep 1 && echo 'indexed:{wildcards.sample}' > {output}"

# ── Per-sample QC (parallel branch, independent of calling) ──────────────
# Simulates: FastQC on mapped reads

rule fastqc:
    # Simulated per-sample QC: consumes the mapped file (not the sorted one),
    # declares 1 thread and 1024 MB, sleeps one second and writes a
    # placeholder report tagged with the sample name.
    input:
        "results/mapped/{sample}.bam"
    output:
        "results/qc/{sample}_fastqc.txt"
    threads: 1
    resources:
        mem_mb=1024
    shell:
        "sleep 1 && echo 'qc:{wildcards.sample}' > {output}"

# ── Joint variant calling (synchronisation barrier) ──────────────────────
# Simulates: GATK HaplotypeCaller on all samples together

rule joint_calling:
    # Barrier step: requires the sorted file and its index for every entry in
    # SAMPLES before it can start. Declares 4 threads and 8192 MB, sleeps four
    # seconds and writes a header-only placeholder VCF.
    input:
        bams=expand(
            "results/sorted/{sample}.bam",
            sample=SAMPLES
        ),
        indices=expand(
            "results/sorted/{sample}.bam.bai",
            sample=SAMPLES
        )
    output:
        "results/raw_variants.vcf"
    threads: 4
    resources:
        mem_mb=8192
    shell:
        "sleep 4 && echo 'CHROM POS REF ALT' > {output}"

# ── Scatter: per-chromosome filtering ────────────────────────────────────
# Simulates: bcftools filter per chromosome (can run in parallel after calling)

rule filter_variants:
    # Scatter step: one job per requested {chrom}, each reading the same raw
    # VCF. Declares 2 threads and 2048 MB, sleeps two seconds and writes a
    # placeholder tagged with the chromosome name.
    input:
        "results/raw_variants.vcf"
    output:
        "results/filtered/{chrom}.vcf"
    threads: 2
    resources:
        mem_mb=2048
    shell:
        "sleep 2 && echo 'filtered:{wildcards.chrom}' > {output}"

# ── Gather ────────────────────────────────────────────────────────────────

rule gather_variants:
    # Gather step: waits for the filtered file of every entry in CHROMOSOMES,
    # sleeps one second, then concatenates them (in CHROMOSOMES order) into a
    # single file. Declares 1 thread and 1024 MB.
    input:
        expand(
            "results/filtered/{chrom}.vcf",
            chrom=CHROMOSOMES
        )
    output:
        "results/final_variants.vcf"
    threads: 1
    resources:
        mem_mb=1024
    shell:
        "sleep 1 && cat {input} > {output}"

# ── Annotation ────────────────────────────────────────────────────────────

rule annotate:
    # Simulated annotation of the gathered variants: declares 2 threads and
    # 2048 MB, sleeps two seconds and writes a fixed placeholder (the input
    # content is not read by the command).
    input:
        "results/final_variants.vcf"
    output:
        "results/annotated.vcf"
    threads: 2
    resources:
        mem_mb=2048
    shell:
        "sleep 2 && echo 'ANNOTATED' > {output}"

# ── MultiQC: joins QC + annotation (final aggregation) ───────────────────

rule multiqc:
    # Final aggregation: needs the QC report of every entry in SAMPLES plus
    # the annotated VCF. Declares 1 thread and 512 MB, sleeps one second and
    # writes a report line listing all of its input paths.
    input:
        qc=expand(
            "results/qc/{sample}_fastqc.txt",
            sample=SAMPLES
        ),
        vcf="results/annotated.vcf"
    output:
        "results/multiqc_report.txt"
    threads: 1
    resources:
        mem_mb=512
    shell:
        "sleep 1 && echo 'MultiQC report: {input}' > {output}"
