import os
from pathlib import Path

from snakemake.utils import min_version
min_version("7.32")

shell.executable("bash")

################################
# Global variables and configurations.
################################
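# The config is expected to look roughly like this (key names are taken from the
# accesses throughout this file; all values are purely illustrative):
#
#   Inputs:
#     assembly: /path/to/assembly.fasta
#     alignment: /path/to/sample.bam
#     depth_txt: /path/to/depth_profile.txt   # optional; leave empty to compute it
#     batch_name: sampleA
#   outputdir: chloroscan_out
#   tmpdir: tmp
#   corgi: {pthreshold: 0.9, min_length: 500, batch_size: 4}
#   mmseqs2: {database: /path/to/mmseqs_tax_db}
#   binning:
#     database: /path/to/binny_database
#     mantis_env: /path/to/mantis_env
#     universal_length_cutoff: 1000
#     clustering: {epsilon_range: ..., hdbscan_min_sample_range: ...}
#     bin_quality: {min_completeness: 70, starting_completeness: 90, purity: 90}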
assembly_path = config['Inputs']['assembly']
ALIGNMENT_FILES = config['Inputs']['alignment']
DEPTH_TEXT_FILE = config['Inputs']['depth_txt']
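# depth_txt is optional: when it is null/absent or the file does not exist, the
# depth profile is computed from the alignments by the rules further below.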
try:
    DEPTH_FILE_EXISTS = Path(DEPTH_TEXT_FILE).exists()
except TypeError:
    DEPTH_FILE_EXISTS = False

OUTPUT_DIR = Path(config['outputdir'])
BATCH_NAME = config['Inputs']['batch_name']

WORKFLOW_DIR = Path(workflow.basedir)
SCR_DIR = WORKFLOW_DIR/"scripts"
MMSEQSDB = config['mmseqs2']['database']
GLOBAL_PREFIX_CONTIGS = "out.mmseqs2"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Create the temporary directory (resolved relative to OUTPUT_DIR when not absolute).
TMP_DIR = Path(config['tmpdir'])
if not TMP_DIR.is_absolute():
    TMP_DIR = OUTPUT_DIR / TMP_DIR
TMP_DIR.mkdir(parents=True, exist_ok=True)

####################################
# Paths and configuration for binny.
####################################
BINNY_MANTIS_ENV = config['binning']['mantis_env']

BINNY_DIR = WORKFLOW_DIR.parent / "binny_Chloroscan"
BINNY_SNAKEFILE = BINNY_DIR / "workflow" / "binny_core.smk"
BINNY_FINAL_OUTPUT_DIR = OUTPUT_DIR / "working/binny"
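# Fall back to the workflow-computed depth profile when none was supplied.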
BINNY_DEPTH_INPUT = DEPTH_TEXT_FILE if DEPTH_FILE_EXISTS else str(OUTPUT_DIR / "working/depth_profile.txt")

BINNY_RUNTIME_DIR = OUTPUT_DIR / "working/binny_runtime"
BINNY_RUNTIME_MANTIS_CFG = BINNY_RUNTIME_DIR / "binny_mantis.cfg"
BINNY_NLTK_READY = BINNY_RUNTIME_DIR / "nltk.ready"
BINNY_NLTK_DIR = BINNY_RUNTIME_DIR / "nltk_data"

# BINNY_DB_ROOT = BINNY_DIR / "resources"
BINNY_DB_DIR = Path(config['binning']['database'])  # must be provided by the user.
BINNY_DB_SENTINEL = BINNY_DB_DIR / "hmms" / "markers" / "markers.hmm"  # proves the database is unpacked.

# Config handed to the binny module (some keys are new to binny; keep snk.yml in sync).
BINNY_CONFIG = {
    "mem": {
        "normal_mem_per_core_gb": 2,
        "big_mem_per_core_gb": 26,
        "big_mem_avail": 0,
    },
    "tmp_dir": str(TMP_DIR / "binny_tmp"),
    "raws": {
        "assembly": str(OUTPUT_DIR / "working/corgi/plastid.fasta"),
        "contig_depth": str(DEPTH_TEXT_FILE if DEPTH_FILE_EXISTS else OUTPUT_DIR / "working/depth_profile.txt"),
        "metagenomics_alignment": ""
    },
    "sample": BATCH_NAME,
    "outputdir": str(OUTPUT_DIR / "working/binny"),
    "db_path": str(BINNY_DB_DIR), # database should be setup intact. Also mantis shall be setup.
    "snakemake_env": "",
    "prokka_env": "",
    "mantis_env": BINNY_MANTIS_ENV,
    "conda_source": str(BINNY_DIR / ".conda"),
    "mask_disruptive_sequences": 'True',
    "extract_scmags": 'True',
    "coassembly_mode": "auto",
    "NX_value": 90,
    "min_cont_length_cutoff": config["binning"]["universal_length_cutoff"],
    "max_cont_length_cutoff": config["binning"]["universal_length_cutoff"],
    "min_cont_length_cutoff_marker": config["binning"]["universal_length_cutoff"],
    "max_cont_length_cutoff_marker": config["binning"]["universal_length_cutoff"],
    "max_n_contigs": 5e5,
    "max_marker_lineage_depth_lvl": 4,
    "distance_metric": "euclidean",
    "kmers": "2,3,4",
    "embedding": {
        "max_iterations": 50
    },
    "clustering": {
        "hdbscan_epsilon_range": config["binning"]["clustering"]["epsilon_range"],
        "hdbscan_min_samples_range": config["binning"]["clustering"]["hdbscan_min_sample_range"],
        "include_depth_initial": 'True',
        "include_depth_main": 'True',
    },
    "bin_quality": {
        "min_completeness": config["binning"]["bin_quality"]["min_completeness"],
        "start_completeness": config["binning"]["bin_quality"]["starting_completeness"],
        "purity": config["binning"]["bin_quality"]["purity"],
    },
    "write_contig_data": "True",
    "mantis_cfg": str(BINNY_RUNTIME_MANTIS_CFG),
    "nltk_ready": str(BINNY_NLTK_READY),
    "nltk_data_dir": str(BINNY_NLTK_DIR),
}
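# NOTE (assumption): 'True'/'False' above are passed as strings, not booleans,
# matching binny's own config format; they are left as strings deliberately.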

########################################
# Record the config for reproducibility.
########################################

CONFIG_RECORD = OUTPUT_DIR / "arguments.txt"
with open(CONFIG_RECORD, "w") as cr:
    cr.write("assembly path: {}\n".format(assembly_path))
    cr.write("abundance profile: {}\n".format(ALIGNMENT_FILES))
    cr.write("depth profile: {}\n".format(DEPTH_TEXT_FILE))
    cr.write("output dir: {}\n".format(OUTPUT_DIR))
    cr.write("contig classification settings:\n")
    cr.write("\tprobability threshold: {}\n".format(config['corgi']['pthreshold']))
    cr.write("\tminimum contig length cutoff: {}\n".format(config['corgi']['min_length']))
    cr.write("\tbatch size for prediction: {}\n".format(config['corgi']['batch_size']))
    cr.write("binning settings: # Note: by default binny automatically uses coassembly mode for the input metagenome.\n")
    cr.write("\tcontig length cutoff for clustering and marker gene detection: {}\n".format(config['binning']['universal_length_cutoff']))
    cr.write("\tHDBSCAN epsilon range: {}\n".format(config['binning']['clustering']['epsilon_range']))
    cr.write("\tHDBSCAN minimum sample range: {}\n".format(config['binning']['clustering']['hdbscan_min_sample_range']))
    cr.write("\tMAG starting completeness: {}\n".format(config['binning']['bin_quality']['starting_completeness']))
    cr.write("\tMAG minimum completeness: {}\n".format(config['binning']['bin_quality']['min_completeness']))
    cr.write("\tMAG purity (1-contamination) cutoff: {}\n".format(config['binning']['bin_quality']['purity']))
    cr.write("MMseqs2 database: {}\n".format(config['mmseqs2']['database']))


rule all:
    input:
        OUTPUT_DIR/"ChloroScan.done"

# The two central rules are corgi_prediction and the binny module below.
# SCR_DIR is absolute because it derives from workflow.basedir.
rule corgi_prediction:
    input: 
        seqs = assembly_path 
    params:
        batch_size = config['corgi']['batch_size'],
        min_length = config['corgi']['min_length'],
        p_threshold = config['corgi']['pthreshold'],
        bash_plastid_accession_script = SCR_DIR/"corgi-gather_accession.sh",
        python_fasta_writer = SCR_DIR/"corgi-fasta_filter.py",
        plastid_contigs_identifier = TMP_DIR/"plastid_accession.txt",
        output_dir=OUTPUT_DIR
    conda:
        "envs/corgi_env.yml"
    output:
        corgi_out = OUTPUT_DIR/"working/corgi/corgi-prediction.csv",
        plastid_contigs = OUTPUT_DIR/"working/corgi/plastid.fasta"
        # here may change the output to a directory for file writing!
    message:
        "Using the machine learning algorithm 'CORGI' to predict contig identity. Don't rise batch size too high."
    resources:
        mem_mb=20000
    log:
        OUTPUT_DIR/"logging_info/corgi.log"
    benchmark:
        OUTPUT_DIR/"logging_info/corgi_benchmarking.tsv"
    shell:
        """
        set -e
        touch {output.plastid_contigs}

        echo "corgi --input {input.seqs} --output-csv {output.corgi_out} --batch-size {params.batch_size} --min-length {params.min_length}"
        corgi --input {input.seqs} --output-csv {output.corgi_out} --batch-size {params.batch_size} --min-length {params.min_length} &> {log}
        ls -lh {output.corgi_out}
        head -2 {output.corgi_out}

        # The header row of the CSV also matches ",Plastid," (it names the per-class
        # probability columns), so subtract it from the raw match count.
        header=1
        raw_plastid_contig_count=$(grep -o ",Plastid," {output.corgi_out} | wc -l)
        plastid_contig_count=$((raw_plastid_contig_count-header))
        N_SEQ=$(wc -l < {output.corgi_out})

        if [ $plastid_contig_count -eq 0 ]; then
            echo "No plastid contigs found; exiting rule."
            exit 0
        fi

        # Gather plastid accessions. The stringent probability threshold (-p) is
        # skipped only when few plastid contigs were predicted (<= 100) and they
        # still make up at least half of all records; any other outcome suggests
        # possible overprediction, so the threshold is applied.
        if [ $plastid_contig_count -le 100 ] && [ $plastid_contig_count -ge $((N_SEQ/2)) ]; then
            echo "bash {params.bash_plastid_accession_script} -c {output.corgi_out} -o {params.plastid_contigs_identifier}"
            bash {params.bash_plastid_accession_script} -c {output.corgi_out} -o {params.plastid_contigs_identifier}
        else
            if [ $plastid_contig_count -gt 100 ] && [ $plastid_contig_count -lt $((N_SEQ/2)) ]; then
                echo "More than 100 plastid contigs found, which may be an overprediction. Please check {output.corgi_out} for details. A more stringent threshold will be applied to gather plastid contigs."
            fi
            echo "bash {params.bash_plastid_accession_script} -c {output.corgi_out} -o {params.plastid_contigs_identifier} -p {params.p_threshold}"
            bash {params.bash_plastid_accession_script} -c {output.corgi_out} -o {params.plastid_contigs_identifier} -p {params.p_threshold}
        fi

        echo "python {params.python_fasta_writer} --assembly {input.seqs} --accession {params.plastid_contigs_identifier} --fasta_file_name {output.plastid_contigs}"
        python {params.python_fasta_writer} --assembly {input.seqs} --accession {params.plastid_contigs_identifier} --fasta_file_name {output.plastid_contigs}

        # Summarize the classification for the user.
        prediction_lines=$(wc -l < {output.corgi_out})
        contig_total=$((prediction_lines-header))
        plastid_match=$(grep -o ">" {output.plastid_contigs} | wc -l)

        echo "total classified contigs containing all 5 domains: $contig_total" > {params.output_dir}/corgi.summary.txt
        echo "total number of plastid contigs: $plastid_match" >> {params.output_dir}/corgi.summary.txt
        """

# Only needed when no precomputed depth profile was supplied: derive it from the alignments.
if not DEPTH_FILE_EXISTS:
    rule contig_depth_calculation:
        input:
            assembly = rules.corgi_prediction.output.plastid_contigs,
            alignment = ALIGNMENT_FILES
        output:
            directory(OUTPUT_DIR/"working/call_contig_depth")
        params:
            script_dir = SCR_DIR,
            tmpdir = TMP_DIR,
            perl_script = SCR_DIR/"calcAvgCoverage.pl"
        conda:
            "envs/DepthCall.yml"
        message:
            "Calculate the depth of each contig in the assembly."
        log:
            OUTPUT_DIR/"logging_info/call_contig_depth.log"
        script:
            "{params.script_dir}/call_contig_depth.sh"

    rule merge_contig_depth:
        input:
            depth_files = OUTPUT_DIR/"working/call_contig_depth"
        output:
            OUTPUT_DIR/"working/depth_profile.txt"
        params:
            script_dir = SCR_DIR
        conda:
            "envs/DepthCall.yml"
        message:
            "Merge the depth of each contig in the assembly."
        script:
            "{params.script_dir}/merge_contig_depth.sh"

# binny is declared as a module below, so it is no longer launched as a separate Snakemake run.
# rule ensure_binny_database:
#     output:
#         sentinel=BINNY_DB_SENTINEL
#     params:
#         db_root=BINNY_DB_ROOT,
#         article_id="32121661",
#         token="~/.figshare/token"
#     conda:
#         "envs/prepare_binny_db.yml"
#     log:
#         OUTPUT_DIR / "logging_info" / "binny_database_download.log"
#     shell:
#         r"""
#         set -euo pipefail

#         if [ ! -f {params.token} ]; then
#             echo "Figshare token file not found at {params.token}. Please create your own figshare account and create a file containing your Figshare token at this location."
#             exit 1
#         fi

#         mkdir -p {params.db_root}/A2K_database

#         if [ -s {output.sentinel} ]; then
#             echo "Binny database already present at {params.db_root}/A2K_database" >> {log}
#             exit 0
#         fi

#         figshare download {params.article_id} -o {params.db_root}/A2K_database >> {log} 2>&1
        
#         touch {output.sentinel}
#         """

rule prepare_binny_mantis_cfg:
    input:
        db_ready=BINNY_DB_SENTINEL
    output:
        cfg=BINNY_RUNTIME_MANTIS_CFG
    params:
        template=BINNY_DIR / "config/binny_mantis.cfg",
        hmm=BINNY_DB_SENTINEL
    run:
        # Patch binny's mantis config template: point the CheckM Pfam HMM at the
        # user-provided database copy and drop the CheckM TIGRFAM reference and
        # its weight, which are not shipped with this workflow.
        text = Path(params.template).read_text()
        text = text.replace(
            "/mnt/lscratch/users/ohickl/binning/tools/binny_devel/database/hmms/checkm_pf/checkm_filtered_pf.hmm",
            str(params.hmm),
        )
        text = text.replace(
            "custom_ref=/mnt/lscratch/users/ohickl/binning/tools/binny_devel/database/hmms/checkm_tf/checkm_filtered_tf.hmm\n",
            "",
        )
        text = text.replace("checkm_filtered_tf_weight=0.5\n", "")
        Path(output.cfg).parent.mkdir(parents=True, exist_ok=True)
        Path(output.cfg).write_text(text)
    
module binny_mod:
    snakefile:
        str(BINNY_SNAKEFILE)
    config:
        BINNY_CONFIG

# Import binny's rules under namespaced aliases; they inherit BINNY_CONFIG from the module declaration above.
use rule prepare_input_data from binny_mod as binny_prepare_input_data
use rule format_assembly from binny_mod as binny_format_assembly
use rule annotate from binny_mod as binny_annotate
use rule prepare_mantis from binny_mod as binny_prepare_mantis
use rule mantis_checkm_marker_sets from binny_mod as binny_mantis_checkm_marker_sets
use rule binny from binny_mod as binny_run

rule check_binny_output:
    input:
        binny_out=rules.binny_run.output,
    params:
        binny_bins=OUTPUT_DIR/"working/binny/bins",
    output:
        check_flag=OUTPUT_DIR/"working/binny/binny_output_check.txt"
    log:
        OUTPUT_DIR/"logging_info/check_binny_output.log"
    shell:
        """
        set -e
        if [ -d "{params.binny_bins}" ] && [ "$(ls -A {params.binny_bins})" ]; then
            echo "Binny produced bins successfully." > {output.check_flag}
        else
            echo "Binny did not produce any bins."
            exit 1
        fi
        """

rule MMseqs2_taxonomy:
    input: 
        flag=rules.check_binny_output.output.check_flag,
        binny_output=rules.binny_run.output
    params:
        prefix=GLOBAL_PREFIX_CONTIGS,
        path_to_contigs=OUTPUT_DIR/"working/binny/intermediary/assembly.formatted.fa"
    output:
        mmseqs2=directory(OUTPUT_DIR/"working/mmseqs2"),
        krona_report=OUTPUT_DIR/"Krona.html",
    log:
        OUTPUT_DIR/"logging_info/mmseqs2.log"
    benchmark:
        OUTPUT_DIR/"logging_info/mmseqs2_benchmark.tsv"
    conda:
        "envs/mmseqs.yml"
    threads:
        os.cpu_count()
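    # Snakemake caps the requested thread count at the --cores limit at runtime.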
    message:
        "Run MMseqs2 taxonomy classification for the plastid contigs."
    shell:
        """
        set -e
        mkdir -p {output.mmseqs2}
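        # Heuristic (assumption): budget ~10 GB of memory per thread for --split-memory-limit.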
        resource="$(( {threads} * 10 ))G"
        # first, create contigdb based on the assembly.
        mmseqs createdb {params.path_to_contigs} {output.mmseqs2}/contigdb
        # then, run taxonomy classification.
        mmseqs taxonomy {output.mmseqs2}/contigdb {MMSEQSDB} {output.mmseqs2}/{params.prefix} {output.mmseqs2}/tmp --threads {threads} \
            --search-type 2 --tax-lineage 2 --split-memory-limit $resource > {log} 2>&1
        # extract results as tsv.
        mmseqs createtsv {output.mmseqs2}/contigdb {MMSEQSDB} {output.mmseqs2}/{params.prefix} {output.mmseqs2}/{params.prefix}.tsv >> {log} 2>&1
        # generate krona report.
        mmseqs taxonomyreport {MMSEQSDB} {output.mmseqs2}/{params.prefix} {output.krona_report} --threads {threads} --report-mode 1 >> {log} 2>&1
        """

rule documenting_binny_results:
    input: 
        contig_level_annotation = OUTPUT_DIR/"working/mmseqs2"
    params: 
        bins = OUTPUT_DIR/"working/binny/bins",
        assembly_files = OUTPUT_DIR/"working/binny/intermediary/assembly.formatted.fa",
        assembly_depth = OUTPUT_DIR/"working/binny/intermediary/assembly.contig_depth.txt",
        marker_gene_gff = OUTPUT_DIR/"working/binny/intermediary/annotation_CDS_RNA_hmms_checkm.gff",
        MMA_summary=OUTPUT_DIR/"MMA.summary.txt",
        ANNOT_FILE="out.mmseqs2.tsv",
        script_dir = SCR_DIR
    output:
        OUTPUT_DIR/"working/summary/cross_ref.tsv",
    benchmark:
        OUTPUT_DIR/"logging_info/summarize_binny.tsv"
    conda:
        "envs/documenting_binny.yml"
    message:
        "Run python script summarize_binny_results.py to record each bin's contig metadata."
    shell:
        """
        $CONDA_PREFIX/bin/python {params.script_dir}/summarize_binny_results.py --assemblyfasta {params.assembly_files} \
            --assemblydepth {params.assembly_depth} --assembly_annot {params.marker_gene_gff} \
            --contig_level_annotation {input.contig_level_annotation}/{params.ANNOT_FILE} \
            --assembly_bin_dir {params.bins} --MMA_SUMMARY {params.MMA_summary} --output {output}
        """

rule refine_bins:
    input:
        cross_ref = OUTPUT_DIR/"working/summary/cross_ref.tsv",
        original_bins = OUTPUT_DIR/"working/binny",
        MMseqs2_prediction = OUTPUT_DIR/"working/mmseqs2",
    params:
        script_dir = SCR_DIR
    output:
        refine_dir = directory(OUTPUT_DIR/"working/refined_bins"),
        summary_info = OUTPUT_DIR/"working/refinement_contig_summary.txt"
    conda:
        "envs/refinement.yml"
    message:
        "Identify potential contaminations in each plastid MAG. Users can decide to use the refined or binny-generated bins."
    log:
        OUTPUT_DIR/"logging_info/refinement.log"
    shell:
        """
        $CONDA_PREFIX/bin/python {params.script_dir}/refine_bins.py --input_cross_ref {input.cross_ref} \
            --input_original_bins {input.original_bins} \
            --input_MMseqs2_prediction {input.MMseqs2_prediction} \
            --output_dir_for_refined_bins {output.refine_dir} \
            --output_summary_info {output.summary_info} \
            --log_file {log}
        """

rule visualize_results:
    input:
        OUTPUT_DIR/"working/summary/cross_ref.tsv",
    params:
        batch_name=BATCH_NAME,
        script_dir = SCR_DIR,
        refine_bins_dir=OUTPUT_DIR/"working/binny",  # points at the original binny output, not working/refined_bins.
    output:
        directory(OUTPUT_DIR/"working/visualizations")
    conda:
        "envs/visualization.yml"
    log:
        OUTPUT_DIR/"logging_info/visualizations.log"
    shell:
        """
        $CONDA_PREFIX/bin/python {params.script_dir}/visualization.py --input_cross_ref {input} \
            --batch_name {params.batch_name} \
            --refine_bins_dir {params.refine_bins_dir} \
            --output_dir {output} >> {log} 2>&1
        """

rule cds_extraction:
    input:
        binny_done=rules.binny_run.output[-1],
    params:
        batch_name=BATCH_NAME,
        gff_file_flag="GFFs",
        script_dir = SCR_DIR,
        path_fraggenescan = OUTPUT_DIR/"working/FragGeneScanRs",
    output:
        CDS_EXTRACTION=directory(OUTPUT_DIR/"working/cds-extraction")
    message:
        "Use the fast ORF-predictor FragGeneScan Rust to predict gene contents of the MAGs, then extract the cds via gffread."
    shell:
        """
        set -e
        mkdir -p {params.path_fraggenescan}/{params.gff_file_flag}
        mkdir -p {output} {output}/cds {output}/faa
        # FragGeneScanRs is assumed to be installed via cargo into ~/.cargo/bin.
        export PATH=$PATH:$HOME/.cargo/bin

        BIN_DIR=$(dirname {input.binny_done})/bins
        echo $BIN_DIR
        BATCH_NAME=$(basename {params.batch_name})

        # Run the ORF predictor on every bin; remember any failure instead of
        # relying on the exit status of the last loop iteration.
        status=0
        for i in $(ls $BIN_DIR); do
            echo $BIN_DIR/$i
            FragGeneScanRs --seq-file-name $BIN_DIR/$i --training-file illumina_5 --thread-num 1 \
                -g {params.path_fraggenescan}/{params.gff_file_flag}/$(basename $i .fasta).gff \
                -n {output}/cds/$BATCH_NAME.$(basename $i .fasta).gene.fasta \
                -a {output}/faa/$BATCH_NAME.$(basename $i .fasta).faa -w 0 || status=1
        done

        if [ $status -ne 0 ]; then
            echo "Error: FragGeneScanRs failed on at least one bin."
            exit 1
        fi
        echo "FragGeneScanRs completed successfully."
        """

rule check_done:
    input:
        OUTPUT_DIR/"working/visualizations",
        OUTPUT_DIR/"working/refined_bins",
        OUTPUT_DIR/"working/binny",
        OUTPUT_DIR/"working/mmseqs2",
        OUTPUT_DIR/"working/cds-extraction",
        OUTPUT_DIR/"working/summary/cross_ref.tsv",
        rules.refine_bins.output.refine_dir,
        OUTPUT_DIR/"Krona.html",
        # anvio_flag = ANVIO_PREPARE
    log:
        OUTPUT_DIR/"logging_info/check_done.log"
    output:
        OUTPUT_DIR/"ChloroScan.done"
    shell:
        """
        # Check that every upstream result exists and is non-empty; only then
        # write the final done-flag.
        flag=0
        for f in {input}; do
            if [ ! -s $f ]; then
                echo "Missing or empty: $f" >> {log}
                flag=1
                break
            fi
        done

        echo "Flag is $flag" >> {log}

        if [ $flag -eq 0 ]; then
            touch {output}
        else
            exit 1
        fi
        """