
from snakemake.utils import Paramspace, validate, min_version
import pandas as pd
from os import makedirs
import os.path
from aux.checkConfig import readDefaults, makeConfig
from aux.checkPatients import checkPatientsBetas, checkPatientInfo
from aux.checkInput import checkInput
from aux.validate_beast import validate_beast_command
import random

# Refuse to run on Snakemake releases older than 6.10.0.
min_version("6.10.0")

# Build the effective configuration. The defaults are read from the schema file
# that sits next to this Snakefile, passed together with the user-supplied
# config to makeConfig (project helper; presumably fills in missing keys —
# TODO confirm), and the result is validated against the same schema.
DEFAULTS = readDefaults(os.path.join(workflow.basedir, "config.schema.yml"))
config = makeConfig(config, DEFAULTS)
validate(config, "config.schema.yml")
# Project helper; presumably checks that the BEAST command is usable — TODO confirm.
validate_beast_command()

# Working directory of the workflow; falls back to the current directory.
workdir: config.get("workdir", ".")

# Paths to the helper scripts and data files shipped with the workflow.
# workflow.source_path() resolves each path relative to this Snakefile and
# returns a location that the rules below can hand to the shell.
XMLgenerator = workflow.source_path("aux/methylationBetas2xml.py")
# Resolved only for their side effect; the returned paths are discarded.
# NOTE(review): presumably these are modules imported by methylationBetas2xml.py
# and must be available next to it — confirm.
_ = workflow.source_path("aux/createXML.py"), workflow.source_path("aux/readInputMethylation.py")
readIDAT = workflow.source_path("aux/readIDAT.R")
methCNV = workflow.source_path("aux/methCNV.R")
selectCpGs = workflow.source_path("aux/select_cpgs.py")
plot_tree = workflow.source_path("aux/plot_tree.R")
plot_post = workflow.source_path("aux/plot_posterior_predictive.py")
convergence = workflow.source_path("aux/mixingAndConvergence.R")
gather_MLE = workflow.source_path("aux/gatherMLE.sh")
model_selection = workflow.source_path("aux/modelSelectionS.R")
integrate_over_S = workflow.source_path("aux/integratingOverS.R")
# Default array manifest (can be overridden through config["manifest"] in
# rule select_cpgs) and the bundled probe list used as a blacklist source.
manifest = workflow.source_path("data/MethylationEPIC_v-1-0_B4.csv")
crossreactivefile = workflow.source_path("data/13059_2016_1066_MOESM1_ESM.csv")

# Names of the sample-sheet columns used throughout the workflow.
patient_col = config.get("patient_col", "Patient")
sample_col = config.get("sample_col", "Sample")
age_col = config.get("age_col", "Age")
# These two are optional and have no fallback name. They must be plain values:
# a trailing comma here would wrap them in a one-element tuple, which is always
# truthy (even when the config value is None) and formats as "('name',)".
age_diagnosis_col = config.get("age_diagnosis_col")
sample_type_col = config.get("sample_type_col")

# Load the sample sheet, run it through the project's checkPatientInfo helper
# and derive the per-patient lookups used by the rules below.
df = checkPatientInfo(pd.read_csv(config.get("patientInfo")), patient_col)
_by_patient = df.groupby(patient_col)
# Distinct patient identifiers, in order of first appearance.
PATIENTS = df[patient_col].unique()
# Series indexed by patient: list of that patient's sample names.
SAMPLES = _by_patient[sample_col].agg(list)
# Dict patient -> largest value in the age column among the patient's samples.
AGES = _by_patient[age_col].max().to_dict()

# Stem-cell sweep, given in the config as "start-stop-step". The values are fed
# to range(), so the stop value itself is not part of the sweep.
stemCellsSweep = config.get("stemCells").split("-")
STEMCELLS = list(range(int(stemCellsSweep[0]), int(stemCellsSweep[1]), int(stemCellsSweep[2])))
CHAIN_LENGTH = config.get("iterations", 750_000)

NCHAINS = config.get("nchains", 3)
# Master seed: taken from the config, otherwise drawn at random. Seeding the
# generator with it makes the per-chain seeds below reproducible for a given SEED.
SEED = config.get("seed", random.randint(0, 100000))
random.seed(SEED)

# One seed per chain. Always a list, so every consumer can treat it uniformly;
# with a single chain that chain simply runs with the master seed.
seeds = [random.randint(0, 100000) for _ in range(NCHAINS)] if NCHAINS > 1 else [SEED]

# Input mode reported by the project's checkInput helper; only "trees" and
# "complete" are accepted. It decides which checkpoint produces miniBetas.
mode = checkInput(config)

# CNV calling is only switched on in "complete" mode, when the config asks for
# it and the 'Group' column contains more than two normal/control matches.
activateCNV = False

if mode == "complete":
    activateCNV = bool(config.get("cnv") and df['Group'].str.lower().str.count('normal|control').sum() > 2)

    ruleorder: split_flucpgs > split_flucpgs_trees
elif mode == "trees":
    ruleorder: split_flucpgs_trees > split_flucpgs
else:
    raise RuntimeError(f"Unexpected inferred mode {mode}")

# A user-supplied CpG list takes precedence over automatic CpG selection.
if not config.get("fcpgs"):
    ruleorder: select_cpgs > force_cpgs
else:
    ruleorder: force_cpgs > select_cpgs

# Both logcombiner rules write the same files; prefer the variant that matches
# the number of chains.
if NCHAINS > 1:
    ruleorder: logcombiner > logcombiner_one
else:
    ruleorder: logcombiner_one > logcombiner

rule all:
    # Default target: per-patient model selection, the per-stem-cell-number
    # posterior plots and trees, and the solution integrated over the sweep.
    # Targets are spelled exactly like the outputs of the producing rules
    # (no trailing slash), so they match those output patterns literally.
    input:
        expand("results/{patient}/selected_model", patient=PATIENTS),
        expand("results/{patient}/{stemcell}cells/{patient}.posterior.pdf", patient=PATIENTS, stemcell=STEMCELLS),
        expand("results/{patient}/{stemcell}cells/{patient}.consensus.tree", patient=PATIENTS, stemcell=STEMCELLS),
        expand("results/{patient}/{stemcell}cells/{patient}.tree.pdf", patient=PATIENTS, stemcell=STEMCELLS),
        expand("results/{patient}/integrated_solution/{patient}.consensus.tree", patient=PATIENTS),
        expand("results/{patient}/integrated_solution/{patient}.tree.pdf", patient=PATIENTS),
        expand("results/{patient}/integrated_solution/{patient}.posterior.pdf", patient=PATIENTS)
    
rule readIDAT:
    # Runs the readIDAT.R helper on the configured input directory and sample
    # sheet, using every core given to the workflow. Results are written to
    # the idat_processed directory under the configured output name.
    input:
        dir = config.get("input"),
        samplesheet = config.get("patientInfo")
    output:
        directory("idat_processed"),
    threads:
        workflow.cores
    params:
        readIDAT = readIDAT,
        sample_col = sample_col,
        name = config.get("output")
    shell:
        """
        Rscript {params.readIDAT} -i {input.dir} -o {output} -n {params.name} -p {input.samplesheet} -c {threads} -s {params.sample_col}
        """

rule callCNVs:
    # Runs the methCNV.R helper on the .RData file inside the readIDAT output
    # directory, deletes that .RData file afterwards and moves the blacklist
    # out of the directory to the workflow's working directory.
    input:
        rules.readIDAT.output
    output:
        expand("{out}.blacklist.csv", out=config.get("output"))
    params:
        name = config.get("output"),
        methCNV=methCNV,
        # Optional flag, only passed when genome_plot is set in the config.
        genome_plot = "--genome_plot" if config.get("genome_plot") else ""
    threads:
        workflow.cores,
    shell:
        """
        Rscript {params.methCNV} -i {input}/{params.name}.RData -c {threads} {params.genome_plot} 
        rm {input}/{params.name}.RData
        mv {input}/{output} {output}
        """

rule select_cpgs:
    # Runs the select_cpgs.py helper on the betas/M/U tables found in the
    # idat_processed directory and writes the selected CpGs to
    # <output>.fluCpGs.csv. Every path and option handed to the script goes
    # through params, so the command does not depend on Snakefile globals.
    input: "idat_processed"
    output: expand("{OUTPUT}.fluCpGs.csv", OUTPUT = config.get("output"))
    params:
        # The manifest can be overridden from the config; the bundled one is the default.
        manifest = config.get("manifest", manifest),
        crossreactivefile = crossreactivefile,
        percent = config.get("percent", 5),
        patientInfo = config.get("patientInfo"),
        name = config.get("output"),
        patient_col = patient_col,
        sample_col = sample_col,
        selectCpGs = selectCpGs
    shell:
        """
        python3 {params.selectCpGs} -e {params.manifest} -p {params.percent} -c {params.crossreactivefile} {output} {input}/{params.name}.betas.csv {input}/{params.name}.M.csv {input}/{params.name}.U.csv {params.patientInfo} -P {params.patient_col} -S {params.sample_col}
        """

rule force_cpgs:
    # Alternative to select_cpgs: build <output>.fluCpGs.csv from a
    # user-supplied list of CpGs (config["fcpgs"]) instead of selecting them.
    input: 
        betas = "idat_processed",
        cpgs = config.get("fcpgs", "filenotfound")
    output: 
        flucpgs = expand("{OUTPUT}.fluCpGs.csv", OUTPUT = config.get("output")),
        cpg_ids = temp("cpgs.tmp")
    params:
        # Prefix of the tables inside the betas directory; the shell command
        # needs it to locate <name>.betas.csv.
        name = config.get("output")
    shell:
        # The betas directory is addressed as {input.betas}: a bare {input}
        # would expand to both inputs and produce a broken path.
        """
        grep -o 'cg[0-9]*' {input.cpgs} > {output.cpg_ids}
        head -n1 {input.betas}/{params.name}.betas.csv > {output.flucpgs}
        fgrep -f {output.cpg_ids} {input.betas}/{params.name}.betas.csv >> {output.flucpgs}
        """

checkpoint split_flucpgs:
    # Splits the selected-CpG table into one CSV per patient inside miniBetas,
    # each keeping only the columns named after that patient's samples.
    input: expand("{OUTPUT}.fluCpGs.csv", OUTPUT = config.get("output")),
    output: temp(directory("miniBetas"))
    run:
        betas = pd.read_csv(input[0])
        # The first column becomes the row index of the table.
        betas.set_index(betas.columns[0], inplace=True)
        out_dir = output[0]
        makedirs(out_dir, exist_ok = True)
        for patient in PATIENTS:
            patient_samples = SAMPLES[patient]
            per_patient = betas.filter(items = patient_samples)
            per_patient.to_csv(f"{out_dir}/{patient}.csv")

checkpoint split_flucpgs_trees:
    # Same split as split_flucpgs, but reads the table directly from the
    # configured input instead of the selected-CpG file.
    input: config.get("input")
    output: temp(directory("miniBetas"))
    run:
        betas = pd.read_csv(input[0])
        # The first column becomes the row index of the table.
        betas.set_index(betas.columns[0], inplace=True)

        out_dir = output[0]
        makedirs(out_dir, exist_ok = True)
        for patient in PATIENTS:
            patient_samples = SAMPLES[patient]
            per_patient = betas.filter(items = patient_samples)
            per_patient.to_csv(f"{out_dir}/{patient}.csv")


def get_blacklists():
    """Return the blacklist sources to merge.

    The bundled cross-reactive probe file always comes first. The output of
    rule callCNVs follows when CNV calling is active, and the user-supplied
    blacklist from the config (if any) comes last.
    """
    cnv_sources = [rules.callCNVs.output] if activateCNV else []
    user_blacklist = config.get("blacklist")
    user_sources = [user_blacklist] if user_blacklist else []
    return [crossreactivefile, *cnv_sources, *user_sources]


rule merge_blacklists:
    # Concatenates every blacklist source and keeps only the probe identifiers
    # (cg followed by digits), one per line.
    input:
        get_blacklists()
    output:
        temp("blacklist.csv")
    shell:
        """
        cat {input} | tr ',' '\t' | grep -o 'cg[0-9]*' > {output}
        """

def aggregate_input(wildcards):
    """Return the per-patient betas file written by the active checkpoint.

    Calling ``.get()`` on the checkpoint is what makes Snakemake wait for the
    split to finish before this input is resolved. The path is built from the
    checkpoint's actual output directory and the concrete patient wildcard,
    so it repeats neither the directory name nor an unformatted placeholder.
    """
    # Only the checkpoint for the current mode is looked up.
    checkpoint = checkpoints.split_flucpgs if mode == "complete" else checkpoints.split_flucpgs_trees
    checkpoint_output = checkpoint.get(**wildcards).output[0]
    return os.path.join(checkpoint_output, f"{wildcards.patient}.csv")

rule getXML:
    # Builds the BEAST XML for one patient and one stem-cell number: removes
    # the blacklisted probes from the patient's betas, then runs the XML
    # generator from inside the XML output directory.
    input: 
        minibetas=lambda wildcards: aggregate_input(wildcards),
        blacklist = rules.merge_blacklists.output,
        patientInfo = config.get("patientInfo"),
    output: 
        xml="results/XMLs/{patient}.{stemcell}cells.xml",
        tmp_input = temp("{patient}.{stemcell}cells.csv.tmp")
    params:
        delta = config.get("delta", .2),
        eta = config.get("eta", .7),
        kappa = config.get("kappa", 50),
        mu = config.get("mu", .1),
        gamma = config.get("gamma", .1),
        lam = config.get("lam", 1),
        iterations = CHAIN_LENGTH,
        precision = config.get("precision", 6),
        sampling = config.get("sampling", 75),
        screenSampling = config.get("screenSampling", 500),
        # Optional switches: an empty string when not requested in the config.
        stripRownames = "--stripRownames" if config.get("stripRownames") else "",
        mle_ps = "--mle-ps" if config.get("mle_ps") else "",
        mle_ss = "--mle-ss" if config.get("mle_ss") else "",
        hme = "--hme" if config.get("hme") else "",
        mle_steps = config.get("mle_steps", 100),
        # Defaults derived from the chain length: iterations/steps and iterations/1000.
        mle_iterations = config.get("mle_iterations", int(config.get("iterations", 750_000)/config.get("mle_steps", 100))),
        mle_sampling = config.get("mle_sampling", int(config.get("iterations", 750_000)/1000)),
        luca_mode = config.get("luca_mode", "auto"),
        age_col = age_col,
        # NOTE(review): this parameter is built but not referenced in the
        # command below — confirm whether the generator should receive it.
        age_diagnosis_col = f"--age-diagnosis-col {age_diagnosis_col}" if age_diagnosis_col else "",
        sample_col = sample_col,
        sample_type_col = sample_type_col,        
        XMLgenerator = XMLgenerator
    shell:
        # Every path used after the `cd` is made absolute first, the sample
        # sheet included; a relative sample-sheet path would otherwise be
        # looked up inside the XML output directory.
        """
        input=$(realpath {input.minibetas})
        blacklist=$(realpath {input.blacklist})
        samplesheet=$(realpath {input.patientInfo})
        tmp="$(pwd)/{output.tmp_input}"
        fgrep -v -f $blacklist $input > $tmp

        cd $(dirname {output.xml})
        name=$(basename {output.xml})
        name=${{name%.xml}}

        python3 {params.XMLgenerator} --samplesheet $samplesheet --input $tmp --output $name --luca-mode {params.luca_mode} --age-col {params.age_col} --sample-col {params.sample_col} --sample-type-col {params.sample_type_col} --stemCells {wildcards.stemcell} --delta {params.delta} --eta {params.eta} --kappa {params.kappa} --mu {params.mu} --gamma {params.gamma} --lambda {params.lam} --iterations {params.iterations} --precision {params.precision} --sampling {params.sampling} --screenSampling {params.screenSampling} {params.stripRownames} {params.mle_ps} {params.mle_ss} {params.hme} --mle-steps {params.mle_steps} --mle-iterations {params.mle_iterations} --mle-sampling {params.mle_sampling}
        """

rule beast:
    # Runs one BEAST chain for a patient / stem-cell number / seed. BEAST is
    # started inside a scratch directory named after the seed; its output
    # files are then moved one level up with the seed added to their names.
    input:
        "results/XMLs/{patient}.{stemcell}cells.xml"
    output:
        multiext("results/{patient}/{stemcell}cells/single_chains/{patient}.{stemcell}cells.seed{seed}", ".trees", ".log", ".ops", ".MLE.log")
    log: 
        "results/{patient}/beast_stdout/{patient}.{stemcell}cells.{seed}.stdout"
    params: 
        directory = "results/{patient}/{stemcell}cells/single_chains",
        seed = "{seed}",
    shell:
        # Paths are made absolute before changing directory. The chain
        # directory is created with `mkdir -p` in the current shell (a `cd`
        # inside a parenthesised subshell would not carry over), and stderr is
        # merged into stdout with 2>&1 so both streams share one file handle
        # instead of overwriting each other.
        """
        log=$(realpath {log})
        input=$(realpath {input})
        mkdir -p {params.directory} && cd {params.directory}
        mkdir {params.seed} && cd {params.seed}
        beast -beagle_off -seed {params.seed} $input > $log 2>&1
        ls | while read file
        do
            fn=$(echo $file | sed 's/cells/cells.seed{params.seed}/g')
            mv $file ../$fn
        done
        cd .. && rm -d {params.seed}
        """

       
rule logcombiner:
    # Several chains: merge the per-seed logs and trees with logcombiner,
    # passing the burn-in as a number of states (fraction * chain length).
    input:
        logs = expand("results/{{patient}}/{{stemcell}}cells/single_chains/{{patient}}.{{stemcell}}cells.seed{seed}.log", seed = seeds),
        trees = expand("results/{{patient}}/{{stemcell}}cells/single_chains/{{patient}}.{{stemcell}}cells.seed{seed}.trees", seed = seeds)
    params: 
        burnin = int(config.get("burnin", 0.1)*CHAIN_LENGTH)
    output:
        combined_logs = "results/{patient}/{stemcell}cells/combined_chains/{patient}.log",
        combined_trees = "results/{patient}/{stemcell}cells/combined_chains/{patient}.trees"
    shell:
        """
        logcombiner -burnin {params.burnin} {input.logs} {output.combined_logs}
        logcombiner -burnin {params.burnin} -trees {input.trees} {output.combined_trees}
        """

rule logcombiner_one:
    # Single chain: nothing to combine. The log is copied without its
    # '#' comment lines and the trees file is copied unchanged.
    input:
        logs = expand("results/{{patient}}/{{stemcell}}cells/single_chains/{{patient}}.{{stemcell}}cells.seed{seed}.log", seed = seeds),
        trees = expand("results/{{patient}}/{{stemcell}}cells/single_chains/{{patient}}.{{stemcell}}cells.seed{seed}.trees", seed = seeds),
    output:
        combined_logs = "results/{patient}/{stemcell}cells/combined_chains/{patient}.log",
        combined_trees = "results/{patient}/{stemcell}cells/combined_chains/{patient}.trees"
    shell:
        """
        grep -v ^"#" {input.logs} > {output.combined_logs}
        cp {input.trees} {output.combined_trees}
        """
        
# Convergence diagnostics and marginal-likelihood gathering for one patient
# and one stem-cell number, across all chains.
rule qc:
    input:
        trees = expand("results/{{patient}}/{{stemcell}}cells/single_chains/{{patient}}.{{stemcell}}cells.seed{seed}.trees", seed = seeds),
        # With several chains the per-chain MLE logs are used; with a single
        # chain the captured BEAST stdout is used instead.
        logs = expand("results/{{patient}}/{{stemcell}}cells/single_chains/{{patient}}.{{stemcell}}cells.seed{seed}.MLE.log", seed = seeds) if NCHAINS > 1 else expand("results/{{patient}}/beast_stdout/{{patient}}.{{stemcell}}cells.{seed}.stdout", seed = seeds),
    output:
        outdir = directory("results/{patient}/{stemcell}cells/convergence/"),
        mle = "results/{patient}/{stemcell}cells/convergence/{patient}.{stemcell}cells.MLE.csv", # Created by the first script, then modified by the second one
    params:
        burnin = config.get("burnin", 0.1),
        gather_MLE = gather_MLE,
        convergence = convergence,
        prefix = "{patient}.{stemcell}cells"
    # The rule declares no `threads`, so {threads} below is Snakemake's default.
    # stderr of the R script is discarded.
    shell:
        """
        Rscript {params.convergence} {params.burnin} {threads} {output.outdir}  {params.prefix} {input.trees} 2> /dev/null
        bash {params.gather_MLE} {input.logs} {output.mle}
        """

  
# Summarises the combined trees into a consensus tree with treeannotator and
# plots it with the plot_tree.R helper.
rule plot_tree:
    input:
        # Output of whichever logcombiner variant matches the number of chains
        # (logcombiner itself cannot be used when there is a single chain).
        trees = rules.logcombiner.output.combined_trees if NCHAINS > 1 else rules.logcombiner_one.output.combined_trees, # results/{patient}/{stemcell}cells/combined_chains/{patient}.trees
        log = rules.logcombiner.output.combined_logs if NCHAINS > 1 else rules.logcombiner_one.output.combined_logs # results/{patient}/{stemcell}cells/combined_chains/{patient}.log
    output:
        consensus = "results/{patient}/{stemcell}cells/{patient}.consensus.tree",
        plot = "results/{patient}/{stemcell}cells/{patient}.tree.pdf"
    params:
        plot_tree = plot_tree,
        # Largest age recorded for the patient in the sample sheet.
        age = lambda wildcards: AGES[wildcards.patient],
        # With several chains logcombiner has already applied the burn-in, so 0 is passed.
        # NOTE(review): int() of a fractional burn-in such as the 0.1 default
        # truncates to 0, so the single-chain case also passes 0 — confirm
        # whether plot_tree.R expects a fraction or a number of states.
        burnin = int(config.get("burnin", 0.1)) if NCHAINS == 1 else 0
    shell:
        """
        treeannotator -heights median {input.trees} {output.consensus}
        Rscript {params.plot_tree} {output.consensus} '{wildcards.patient}' {output.plot} {params.burnin} {params.age} {input.log}
        """

rule plot_posterior:
    # Posterior-predictive plot for one patient and one stem-cell number,
    # built from the patient's betas and the combined log. The qc MLE table is
    # listed as an input but is not passed to the command.
    input:
        minibetas = lambda wildcards: aggregate_input(wildcards),
        log = rules.logcombiner.output.combined_logs if NCHAINS > 1 else rules.logcombiner_one.output.combined_logs,
        best = rules.qc.output.mle,
        patientInfo = config.get("patientInfo"),
    output: 
        "results/{patient}/{stemcell}cells/{patient}.posterior.pdf"
    params:
        plot_post = plot_post,
        sample_col = sample_col,
        age_col = age_col,
    shell:
        """
        python3 {params.plot_post} --betas {input.minibetas} --log {input.log} --samplesheet {input.patientInfo} --age_col {params.age_col} --sample_col {params.sample_col} --output {output}
        """

rule model_selection:
    # Runs the model-selection script over the MLE tables of every stem-cell
    # number in the sweep. The stem-cell number in the third column of the
    # second line of the best-fit table is then used to symlink the matching
    # results directory into the selected_model folder.
    input:
        mle = expand("results/{{patient}}/{stemcell}cells/convergence/{{patient}}.{stemcell}cells.MLE.csv", stemcell = STEMCELLS),
    params:
        patient_dir = "results/{patient}/",
        model_selection = model_selection
    output:
        folder = directory("results/{patient}/selected_model"),
        best = "results/{patient}/selected_model/{patient}.bestFitS.csv",
        allS = "results/{patient}/selected_model/{patient}.allS.csv",
    shell:
        """
        Rscript {params.model_selection} {output.folder} {input.mle}
        optimal_cells=$(awk -F, 'NR==2 {{print $3}}' {output.best}) 
        ln -s $(realpath {params.patient_dir}/${{optimal_cells}}cells/) {output.folder}
        """


rule integrate_over_S:
    # Runs the integratingOverS.R helper on the combined trees of every
    # stem-cell number together with the all-S table from model selection.
    # Only the tree file is named on the command line, so the final test makes
    # the job fail if the companion log file was not produced as well.
    input:
        allS = "results/{patient}/selected_model/{patient}.allS.csv",
        trees = expand("results/{{patient}}/{stemcell}cells/combined_chains/{{patient}}.trees", stemcell=STEMCELLS)
    params:
        integrate_over_S = integrate_over_S,
        seed = SEED
    output:
        overS_tree = "results/{patient}/integrated_solution/{patient}.overS.trees",
        overS_log = "results/{patient}/integrated_solution/{patient}.overS.log"
    shell: 
        """
        Rscript {params.integrate_over_S} -o {output.overS_tree} -b 0 -s {params.seed} -m {input.allS} {input.trees}
        [[ -f {output.overS_log} ]]
        """

# Reuses plot_tree (same params and shell command) for the solution integrated
# over the stem-cell sweep; only the input and output files are replaced.
use rule plot_tree as plot_tree_integration with:
    input:
        trees = "results/{patient}/integrated_solution/{patient}.overS.trees",
        log = "results/{patient}/integrated_solution/{patient}.overS.log" 
    output:
        consensus = "results/{patient}/integrated_solution/{patient}.consensus.tree",
        plot = "results/{patient}/integrated_solution/{patient}.tree.pdf"

rule plot_posterior_integration:
    # Same posterior-predictive plot as plot_posterior, but built from the log
    # of the solution integrated over the stem-cell sweep.
    input:
        minibetas = lambda wildcards: aggregate_input(wildcards),
        log = "results/{patient}/integrated_solution/{patient}.overS.log",
        patientInfo = config.get("patientInfo"),
    output: 
        "results/{patient}/integrated_solution/{patient}.posterior.pdf"
    params:
        plot_post = plot_post,
        sample_col = sample_col,
        age_col = age_col,
    shell:
        """
        python3 {params.plot_post} --betas {input.minibetas} --log {input.log} --samplesheet {input.patientInfo} --age_col {params.age_col} --sample_col {params.sample_col} --output {output}
        """