from pathlib import Path
import yaml

# Register config.yaml with Snakemake.
configfile: "config.yaml"

# The helpers and rules below read settings from `cfg`, which is parsed
# directly from the file on disk. An empty file yields an empty dict.
# NOTE(review): because `cfg` is loaded independently of Snakemake's own
# `config` object, command-line overrides (--config / --configfile) are
# presumably not reflected in `cfg` -- confirm this is intended.
# The encoding is pinned so parsing does not depend on the platform locale.
with open("config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f) or {}

# -------------------- Multi-sample support --------------------
def get_samples():
    """Return the sample names declared in the configuration.

    Resolution order:
      1. ``cfg["samples"]`` is returned as-is when the key is present.
      2. Otherwise, when ``input_fasta`` is set, a one-element list holding
         that file's stem (file name without its final suffix) is returned.
      3. Otherwise an empty list is returned.
    """
    if "samples" in cfg:
        return cfg["samples"]

    if "input_fasta" in cfg:
        fasta_path = Path(cfg["input_fasta"])
        return [fasta_path.stem]

    return []

def get_batch_mode(step):
    """Return the batch mode configured for ``step``.

    Looks ``step`` up in the ``batch_modes`` mapping of the configuration and
    returns ``"single"`` when the step has no entry.

    A ``batch_modes:`` key that is present but empty is parsed by YAML as
    ``None``; it is treated here like a missing mapping instead of raising
    ``AttributeError``.
    """
    modes = cfg.get("batch_modes") or {}
    return modes.get(step, "single")

def get_input_files(step):
    """Resolve the sample -> input-path mapping for a pipeline step.

    Supported steps are ``blast_remote``, ``mafft_conserved``,
    ``primer3_design`` and ``specificity``; any other value yields ``{}``.

    For every supported step:
      * in ``"multi"`` batch mode, the step's per-sample config mapping is used
        when it is present in the configuration;
      * otherwise a mapping over ``SAMPLES`` is built, pointing either at the
        step's single custom input (only in ``"single"`` mode, when set) or at
        the file the previous pipeline step is expected to produce.

    The ``specificity`` step additionally requires custom files to exist on
    disk; missing ones fall back to ``primer/{sample}.raw.tsv``.
    """
    if step not in ("blast_remote", "mafft_conserved", "primer3_design", "specificity"):
        return {}

    mode = get_batch_mode(step)
    is_multi = mode == "multi"
    is_single = mode == "single"

    if step == "blast_remote":
        if is_multi and "input_fastas" in cfg:
            return cfg["input_fastas"]
        # Single mode or fallback: every sample shares the one input FASTA.
        shared_fasta = cfg.get("input_fasta", "")
        return dict.fromkeys(SAMPLES, shared_fasta)

    if step in ("mafft_conserved", "primer3_design"):
        if step == "mafft_conserved":
            multi_key, single_key = "custom_blast_inputs", "custom_blast_input"
        else:
            multi_key, single_key = "custom_bed_inputs", "custom_bed_input"

        if is_multi and multi_key in cfg:
            return cfg[multi_key]

        shared_custom = cfg.get(single_key, "")
        if shared_custom and is_single:
            return dict.fromkeys(SAMPLES, shared_custom)

        # No usable custom input: point at the upstream step's expected output.
        if step == "mafft_conserved":
            return {sample: f"blast/{sample}.xml" for sample in SAMPLES}
        return {sample: f"mafft/{sample}.conserved.bed" for sample in SAMPLES}

    # step == "specificity": custom inputs count only if they exist on disk.
    if is_multi and "custom_primer_inputs" in cfg:
        overrides = cfg["custom_primer_inputs"]
        resolved = {}
        for sample in SAMPLES:
            candidate = overrides.get(sample, "")
            if candidate and Path(candidate).exists():
                resolved[sample] = candidate
            else:
                resolved[sample] = f"primer/{sample}.raw.tsv"
        return resolved

    shared_custom = cfg.get("custom_primer_input", "")
    if shared_custom and is_single and Path(shared_custom).exists():
        return dict.fromkeys(SAMPLES, shared_custom)
    return {sample: f"primer/{sample}.raw.tsv" for sample in SAMPLES}

# Sample names that drive every expand() target and input mapping below.
SAMPLES = get_samples()

# Fallback sample name for single mode: the first configured sample, or the
# placeholder "sample" when no samples are configured.
if SAMPLES:
    DEFAULT_SAMPLE = SAMPLES[0]
else:
    DEFAULT_SAMPLE = "sample"

# -------------------- Rule All --------------------
rule all:
    """Request the final specificity table for every configured sample."""
    input:
        # Generate targets based on the final step we want to reach.
        # With no samples configured the target list is empty.
        expand("primer/{sample}.specificity.tsv", sample=SAMPLES) if SAMPLES else []

# Additional convenience rules for partial runs
rule blast_remote_all:
    """Partial run: request the BLAST XML output for every configured sample."""
    input:
        # Empty target list when no samples are configured.
        expand("blast/{sample}.xml", sample=SAMPLES) if SAMPLES else []

rule mafft_all:
    """Partial run: request the conserved-region BED file for every configured sample."""
    input:
        # Empty target list when no samples are configured.
        expand("mafft/{sample}.conserved.bed", sample=SAMPLES) if SAMPLES else []

rule primer3_all:
    """Partial run: request the raw primer table for every configured sample."""
    input:
        # Empty target list when no samples are configured.
        expand("primer/{sample}.raw.tsv", sample=SAMPLES) if SAMPLES else []

# Alternative target rule for running only the BLAST step.
# NOTE(review): its input list is identical to blast_remote_all, so the two
# targets appear interchangeable -- confirm whether both are needed.
rule run_blast_only:
    """Run only the BLAST step (requests the same files as blast_remote_all)"""
    input:
        expand("blast/{sample}.xml", sample=SAMPLES) if SAMPLES else []

# -------------------- Rules --------------------

rule blast_remote:
    """Run scripts/remote_blast.py on one sample's FASTA and write the XML result.

    Requires cfg["ncbi"]["email"]; the other NCBI settings have defaults.
    """
    input:
        # Per-sample FASTA from get_input_files(); samples missing from that
        # mapping fall back to the single `input_fasta` setting.
        fa=lambda wc: get_input_files("blast_remote").get(wc.sample, cfg.get("input_fasta", ""))
    output:
        xml="blast/{sample}.xml"
    log:
        "logs/blast/{sample}.log"
    params:
        email=cfg["ncbi"]["email"],
        # Defaults to an empty string, so it is quoted in the shell command.
        key=cfg["ncbi"].get("api_key", ""),
        db=cfg["ncbi"].get("db", "nt"),
        task=cfg["ncbi"].get("task", "megablast"),
        hit=cfg["ncbi"].get("hitlist", 120),
        ev=1e-5,
        # `ent` is not referenced by the shell command; `entrez_flag` carries
        # the query and expands to nothing when no query is configured.
        ent=cfg["ncbi"].get("entrez_query"),
        entrez_flag=lambda wc: f"--entrez-query \"{cfg['ncbi'].get('entrez_query')}\"" if cfg["ncbi"].get("entrez_query") else ""
    conda:
        "envs/biopy.yaml"
    shell:
        r"""
        export PYTHONPATH=$PWD:${{PYTHONPATH:-}}
        python scripts/remote_blast.py {input.fa} {output.xml} \
          --email {params.email} --api-key "{params.key}" --db {params.db} \
          --task {params.task} --hitlist-size {params.hit} --evalue {params.ev} {params.entrez_flag} 2>&1 | tee {log}
        """

rule mafft_conserved:
    """Run scripts/mafft_pick_conserved.py for one sample.

    Takes the sample's FASTA and BLAST XML and writes an alignment FASTA and
    a conserved-region BED file. Tuning values come from the `entropy` and
    `homology_filter` config sections, each with a default.
    """
    input:
        # Same FASTA resolution as the blast_remote rule.
        fa=lambda wc: get_input_files("blast_remote").get(wc.sample, cfg.get("input_fasta", "")),
        # Custom BLAST XML when configured, otherwise the pipeline's own output.
        xml=lambda wc: get_input_files("mafft_conserved").get(wc.sample, cfg.get("custom_blast_input", f"blast/{wc.sample}.xml"))
    output:
        aln="mafft/{sample}.aln.fasta",
        bed="mafft/{sample}.conserved.bed"
    log:
        "logs/mafft/{sample}.log"
    params:
        email=cfg["ncbi"]["email"],
        # Defaults to an empty string, so it is quoted in the shell command.
        key=cfg["ncbi"].get("api_key", ""),
        maxh=cfg.get("entropy", {}).get("max_hits", 100),
        win=cfg.get("entropy", {}).get("window", 20),
        step=cfg.get("entropy", {}).get("step", 4),
        metric=cfg.get("entropy", {}).get("metric", "entropy"),
        # The `cutoff` setting is passed to the script as --identity.
        cut=cfg.get("entropy", {}).get("cutoff", 0.9),
        colmc=cfg.get("entropy", {}).get("col_min_coverage", 0.8),
        winmc=cfg.get("entropy", {}).get("win_min_cov_frac", 0.8),
        # `lcr`, `per_species_one`, `downweight_duplicates` and `pre_crop` are
        # not referenced by the shell command; the matching *_flag params
        # below turn the same settings into on/off command-line switches.
        lcr=cfg.get("entropy", {}).get("lcr_filter", True),
        lcrH=cfg.get("entropy", {}).get("lcr_min_entropy_bits", 1.2),
        lcrP=cfg.get("entropy", {}).get("lcr_max_homopolymer", 5),
        min_qcov=cfg.get("homology_filter", {}).get("min_qcov", 0.7),
        min_pident=cfg.get("homology_filter", {}).get("min_pident", 0.8),
        min_len=cfg.get("homology_filter", {}).get("min_len", 20),
        max_len=cfg.get("homology_filter", {}).get("max_len", 10000),
        max_N_frac=cfg.get("homology_filter", {}).get("max_N_frac", 0.10),
        per_species_one=cfg.get("homology_filter", {}).get("per_species_one", True),
        downweight_duplicates=cfg.get("homology_filter", {}).get("downweight_duplicates", True),
        max_after_filter=cfg.get("homology_filter", {}).get("max_after_filter", 60),
        pre_crop=cfg.get("homology_filter", {}).get("pre_crop", True),
        pad=cfg.get("homology_filter", {}).get("pad", 50),
        lcr_flag=lambda wc: "--lcr-filter" if cfg.get("entropy", {}).get("lcr_filter", True) else "--no-lcr-filter",
        species_flag=lambda wc: "--per-species-one" if cfg.get("homology_filter", {}).get("per_species_one", True) else "--no-per-species-one",
        dup_flag=lambda wc: "--downweight-duplicates" if cfg.get("homology_filter", {}).get("downweight_duplicates", True) else "--no-downweight-duplicates",
        crop_flag=lambda wc: "--pre-crop" if cfg.get("homology_filter", {}).get("pre_crop", True) else "--no-pre-crop"
    conda:
        "envs/mafft.yaml"
    shell:
        r"""
        export PYTHONPATH=$PWD:${{PYTHONPATH:-}}
        python scripts/mafft_pick_conserved.py {input.fa} {input.xml} {output.aln} {output.bed} \
          --email {params.email} --api-key "{params.key}" --max-hits {params.maxh} \
          --window {params.win} --step {params.step} --metric {params.metric} --identity {params.cut} \
          --col-min-coverage {params.colmc} --win-min-cov-frac {params.winmc} \
          {params.lcr_flag} --lcr-min-entropy-bits {params.lcrH} --lcr-max-homopolymer {params.lcrP} \
          --min-qcov {params.min_qcov} --min-pident {params.min_pident} \
          --min-len {params.min_len} --max-len {params.max_len} \
          --max-N-frac {params.max_N_frac} \
          {params.species_flag} \
          {params.dup_flag} \
          --max-after-filter {params.max_after_filter} \
          {params.crop_flag} --pad {params.pad} 2>&1 | tee {log}
        """

rule primer3_design:
    """Run scripts/primer3_design.py for one sample.

    Takes the sample's FASTA and conserved-region BED file and writes the raw
    primer table. All settings come from the `p3` config section, each with a
    default.
    """
    input:
        # Same FASTA resolution as the blast_remote rule.
        fa=lambda wc: get_input_files("blast_remote").get(wc.sample, cfg.get("input_fasta", "")),
        # Custom BED when configured, otherwise the pipeline's own mafft output.
        bed=lambda wc: get_input_files("primer3_design").get(wc.sample, cfg.get("custom_bed_input", f"mafft/{wc.sample}.conserved.bed"))
    output:
        tsv="primer/{sample}.raw.tsv"
    log:
        "logs/primer3/{sample}.log"
    params:
        # Optional binding-site flags; each expands to nothing when unset.
        left_bind_flag=lambda wc: f"--left-bind \"{cfg.get('p3', {}).get('left_bind')}\"" if cfg.get('p3', {}).get('left_bind') else "",
        right_bind_flag=lambda wc: f"--right-bind \"{cfg.get('p3', {}).get('right_bind')}\"" if cfg.get('p3', {}).get('right_bind') else "",
        # "include" when neither binding site is configured, otherwise "target".
        mode=lambda wc: "include" if not cfg.get('p3', {}).get('left_bind') and not cfg.get('p3', {}).get('right_bind') else "target",
        size_range=cfg.get("p3", {}).get('size_range', '90-180'),
        num_return=cfg.get("p3", {}).get('num_return', 5),
        opt_size=cfg.get("p3", {}).get('opt_size', 20),
        min_size=cfg.get("p3", {}).get('min_size', 18),
        max_size=cfg.get("p3", {}).get('max_size', 27),
        opt_tm=cfg.get("p3", {}).get('opt_tm', 60.0),
        min_tm=cfg.get("p3", {}).get('min_tm', 57.0),
        max_tm=cfg.get("p3", {}).get('max_tm', 63.0),
        min_gc=cfg.get("p3", {}).get('min_gc', 35.0),
        max_gc=cfg.get("p3", {}).get('max_gc', 65.0),
        gc_clamp=cfg.get("p3", {}).get('gc_clamp', 0),
        max_polyx=cfg.get("p3", {}).get('max_polyx', 4),
        salt_monovalent=cfg.get("p3", {}).get('salt_monovalent', 50.0),
        salt_divalent=cfg.get("p3", {}).get('salt_divalent', 1.5),
        dntp_conc=cfg.get("p3", {}).get('dntp_conc', 0.6),
        max_self_any_th=cfg.get("p3", {}).get('max_self_any_th', 45.0),
        max_self_end_th=cfg.get("p3", {}).get('max_self_end_th', 35.0),
        max_hairpin_th=cfg.get("p3", {}).get('max_hairpin_th', 24.0),
        # On/off switch derived from `use_conserved_windows` (defaults to on).
        conserved_flag=lambda wc: "--use-conserved-windows" if cfg.get("p3", {}).get('use_conserved_windows', True) else "--no-use-conserved-windows"
    conda:
        "envs/primer3.yaml"
    shell:
        r"""
        export PYTHONPATH=$PWD:${{PYTHONPATH:-}}
        python scripts/primer3_design.py {input.fa} {input.bed} {output.tsv} \
          --mode {params.mode} \
          {params.left_bind_flag} \
          {params.right_bind_flag} \
          {params.conserved_flag} \
          --size-range "{params.size_range}" \
          --num-return {params.num_return} \
          --opt-size {params.opt_size} --min-size {params.min_size} --max-size {params.max_size} \
          --opt-tm {params.opt_tm} --min-tm {params.min_tm} --max-tm {params.max_tm} \
          --min-gc {params.min_gc} --max-gc {params.max_gc} \
          --gc-clamp {params.gc_clamp} --max-polyx {params.max_polyx} \
          --salt-monovalent {params.salt_monovalent} --salt-divalent {params.salt_divalent} --dntp-conc {params.dntp_conc} \
          --max-self-any-th {params.max_self_any_th} --max-self-end-th {params.max_self_end_th} --max-hairpin-th {params.max_hairpin_th} \
          > {log} 2>&1
        """

rule specificity:
    """Run scripts/check_specificity.py on one sample's primer table.

    Reads the raw primer table (custom or pipeline-produced) and writes the
    specificity table. Settings come from the `specificity` config section;
    cfg["ncbi"]["email"] is required.
    """
    # NOTE(review): unlike the other script-running rules in this file, this
    # shell command does not export PYTHONPATH -- confirm that
    # check_specificity.py does not need project-local imports.
    input:
        # Custom primer table when configured and present on disk, otherwise
        # the primer3_design output for this sample.
        tsv=lambda wc: get_input_files("specificity").get(wc.sample, f"primer/{wc.sample}.raw.tsv")
    output:
        tsv2="primer/{sample}.specificity.tsv"
    log:
        "logs/specificity/{sample}.log"
    params:
        email=cfg["ncbi"]["email"],
        # Defaults to an empty string, so it is quoted in the shell command.
        key=cfg["ncbi"].get("api_key", ""),
        db=cfg.get("specificity", {}).get("db", "nt"),
        hitlist=cfg.get("specificity", {}).get("hitlist", 1000),
        size_range=cfg.get("specificity", {}).get("size_range", "90-200"),
        min_tail=cfg.get("specificity", {}).get("min_tail", 10),
        # Optional filters; each flag expands to nothing when its setting is unset.
        entrez_include_flag=lambda wc: f'--entrez-include "{cfg.get("specificity", {}).get("entrez_include")}"' if cfg.get("specificity", {}).get("entrez_include") else "",
        entrez_exclude_flag=lambda wc: f'--entrez-exclude "{cfg.get("specificity", {}).get("entrez_exclude")}"' if cfg.get("specificity", {}).get("entrez_exclude") else "",
        taxid_include_flag=lambda wc: f'--taxid-include {cfg.get("specificity", {}).get("taxid_include")}' if cfg.get("specificity", {}).get("taxid_include") else "",
        taxid_exclude_flag=lambda wc: f'--taxid-exclude {cfg.get("specificity", {}).get("taxid_exclude")}' if cfg.get("specificity", {}).get("taxid_exclude") else "",
        # Performance optimization flags
        fast_mode_flag=lambda wc: "--fast-mode" if cfg.get("specificity", {}).get("fast_mode", False) else "",
        # --max-primers is passed only when `max_primers` is set AND
        # `check_all_primers` is explicitly false (it defaults to true).
        max_primers_flag=lambda wc: f'--max-primers {cfg.get("specificity", {}).get("max_primers")}' if cfg.get("specificity", {}).get("max_primers") and not cfg.get("specificity", {}).get("check_all_primers", True) else "",
    conda:
        "envs/biopy.yaml"
    shell:
        r"""
        python scripts/check_specificity.py {input.tsv} {output.tsv2} \
          --email {params.email} --api-key "{params.key}" \
          --db {params.db} \
          --hitlist {params.hitlist} \
          --size-range "{params.size_range}" \
          --min-tail {params.min_tail} \
          {params.entrez_include_flag} \
          {params.entrez_exclude_flag} \
          {params.taxid_include_flag} \
          {params.taxid_exclude_flag} \
          {params.fast_mode_flag} \
          {params.max_primers_flag} \
          > {log} 2>&1
        """