"""
Steps:
1. Create base environment:
    conda env create -f ../envs/create_metapackage.yml -n lyrebird-metapackage
2. Install vcontact3 and download the latest vcontact3 database
    conda env create -f ../envs/vcontact3.yml -n vcontact3
    conda activate vcontact3
    vcontact3 prepare_databases --get-version 230 --set-location ./db (or use --get-version latest)
3. Update the variables vcontact_env and vcontact_db_path below with your conda environment name and database path
4. Download the metaVR metadata and genomes (a >77 GB download; ensure you have enough disk space), then set metavr_metadata_fp and metavr_genomes_fp below to the downloaded file paths
    wget https://www.meta-virome.org/DownloadUvigMetadata
    wget https://www.meta-virome.org/DownloadUvigFasta
5. Run the Snakefile:
    conda activate lyrebird-metapackage
    snakemake --use-conda --conda-frontend mamba --cores 64 --profile aqua --directory <output_directory> --conda-prefix <path_to_conda_envs>  
"""
import os
from Bio import SeqIO

# User-specific settings. Each value can be overridden at run time without editing
# this file, e.g. `snakemake --config vcontact_db_path=/path/to/db` or via a
# --configfile; when no override is given, the defaults below are used.
vcontact_env = config.get("vcontact_env", "")  # pre-installed vcontact3 conda environment
vcontact_db_path = config.get("vcontact_db_path", "")  # downloaded vcontact3 database path
# downloaded metaVR metadata file path
metavr_metadata_fp = config.get(
    "metavr_metadata_fp",
    "/mnt/hpccs01/work/microbiome/msingle/rossenzhao/metaVR/uvig_metadata.tsv.gz",
)
# downloaded metaVR genomes file path
metavr_genomes_fp = config.get(
    "metavr_genomes_fp",
    "/mnt/hpccs01/work/microbiome/msingle/rossenzhao/metaVR/IMGVR5_UViG.fna.gz",
)

# Lightweight bookkeeping rules that run on the submitting host rather than as cluster jobs.
localrules: order_genomes_fps, process_taxonomy

rule all:
    # Default targets for a bare `snakemake` run. Only gene prediction is
    # requested for now; "vcontact3.done" is not yet a target.
    # TODO: fill in final output files
    input:
        "prodigal_gv.done",

rule get_ictv_genomes:
    # Fetch genomes from the current ICTV VMR with scripts/ictv-download.py.
    # Everything lands under ictv/; ictv/done is touched once the script succeeds.
    output:
        done = touch("ictv/done"),
        outdir = directory("ictv")
    params:
        url = 'https://ictv.global/vmr/current'
    conda:
        "envs/ictv-download.yml"
    resources:
        mem_mb = 20 * 1024,
    log:
        "logs/get_ictv_genomes.log"
    shell:
        "python {workflow.basedir}/scripts/ictv-download.py " \
        "--url {params.url} " \
        "--outdir {output.outdir} &> {log}"

rule get_metavr_genomes:
    # Run scripts/metavr-download.py on the locally downloaded metaVR metadata
    # table and genome FASTA, writing results under metavr/ (order_genomes_fps
    # later reads metavr/metavr_filtered_metadata.tsv and metavr/genomes/).
    input:
        metadata_fp = metavr_metadata_fp,
        genomes_fp = metavr_genomes_fp,
    output:
        done = touch("metavr/done"),
        outdir = directory("metavr")
    conda:
        "envs/ictv-download.yml"
    resources:
        mem_mb = 128 * 1024,
    log:
        "logs/get_metavr_genomes.log"
    shell:
        "python {workflow.basedir}/scripts/metavr-download.py --metadata-file {input.metadata_fp} " \
        "--genome-file {input.genomes_fp} --outdir {output.outdir} &> {log}"

rule order_genomes_fps:
    # Write the genome FASTA path list that galah clusters: all ICTV genomes
    # first, then all metaVR genomes, each group sorted longest-first by the
    # integer "length" column of its metadata table (ties keep file order).
    input:
        ictv_done = "ictv/done",
        ictv_dir = "ictv",
        metavr_done = "metavr/done",
        metavr_dir = "metavr"
    params:
        ictv_metadata = "ictv/genome_metadata.tsv",
        metavr_metadata = "metavr/metavr_filtered_metadata.tsv"
    output:
        done = touch("ordered_genomes.done"),
        ordered_filepaths = "ordered_genome_filepaths.txt"
    run:
        import csv
        import os.path

        def longest_first(metadata_tsv, genome_dir, id_column):
            # Map each metadata row to <genome_dir>/genomes/<id>.fna and return
            # the paths ordered by descending length (stable sort).
            with open(metadata_tsv, 'r') as handle:
                records = [
                    (os.path.join(genome_dir, "genomes", f"{row[id_column]}.fna"), int(row['length']))
                    for row in csv.DictReader(handle, delimiter='\t')
                ]
            records.sort(key=lambda record: record[1], reverse=True)
            return [path for path, _ in records]

        ordered_paths = longest_first(params.ictv_metadata, input.ictv_dir, 'genome')
        ordered_paths += longest_first(params.metavr_metadata, input.metavr_dir, 'uvig')
        with open(output.ordered_filepaths, 'w') as out:
            out.writelines(f"{path}\n" for path in ordered_paths)

rule galah_cluster:
    # Cluster the ordered genomes with galah at 95% ANI / 85% minimum aligned
    # fraction. Finch preclustering is used for speed. Cluster membership goes
    # to the TSV; representatives go to galah_clusters/.
    input:
        genome_filepaths = "ordered_genome_filepaths.txt",
        ordered_genomes_done = "ordered_genomes.done"
    output:
        done = touch("galah_clustering.done"),
        cluster_rep_tsv = "galah_cluster_representatives.tsv",
        cluster_dir = directory("galah_clusters")
    conda:
        "envs/galah.yml"
    threads: 64
    resources:
        mem_mb = 256 * 1024,
    log:
        "logs/galah_clustering.log"
    shell:
        "galah cluster --genome-fasta-list {input.genome_filepaths} " \
        "--threads {threads} --ani 95 --min-aligned-fraction 85 " \
        "--fragment-length 500 --precluster-method finch --precluster-ani 95 " \
        "--output-cluster-definition {output.cluster_rep_tsv} " \
        "--output-representative-fasta-directory {output.cluster_dir} -v &> {log}"

rule prodigal_gv:
    # Predict genes on every galah cluster representative with prodigal-gv
    # (-p meta), one prodigal-gv process per genome via GNU parallel.
    # Per-genome logs go to logs/prodigal_gv/<genome file name>.log.
    input:
        cluster_dir = "galah_clusters"
    params:
        log_dir = "logs/prodigal_gv"
    output:
        done = touch("prodigal_gv.done"),
        transcripts_dir = directory("prodigal_transcripts"),
        proteins_dir = directory("prodigal_proteins")
    threads: 64
    resources:
        mem_mb = 500 * 1024,
    conda:
        "envs/prodigal-gv.yml"
    # Genome files are named <id>.fna by order_genomes_fps (the same pattern
    # prepare_for_vcontact3 globs for); representatives are assumed to keep
    # those names. Abort if none are found: otherwise parallel gets no input,
    # exits 0, and the rule "succeeds" with empty output directories.
    shell:
        "mkdir -p {params.log_dir} {output.transcripts_dir} {output.proteins_dir} ; " \
        "if [ -z \"$(find {input.cluster_dir} -name '*.fna' -print -quit)\" ] ; then " \
        "echo 'prodigal_gv: no *.fna genomes found in {input.cluster_dir}' >&2 ; exit 1 ; fi ; " \
        "find {input.cluster_dir} -name '*.fna' | " \
        "parallel -j {threads} --will-cite -- 'prodigal-gv -i {{}} -a {output.proteins_dir}/{{/}}_protein.faa -d {output.transcripts_dir}/{{/}}_transcript.fna -p meta &> {params.log_dir}/{{/}}.log'"

rule prepare_for_vcontact3:
    # Concatenate every *.fna under the galah representative directory into one
    # FASTA for vContact3. The directory holds many files, so `find -exec cat {} +`
    # batches names per cat call instead of relying on a shell glob.
    input:
        galah_dir = "galah_clusters",
        galah_done = "galah_clustering.done",
    output:
        done = touch("vcontact3_prepared.done"),
        concat_genomes = "vcontact3_concat_genomes.fna",
    shell:
        "find {input.galah_dir} -name '*.fna' " \
        "-exec cat {{}} + > {output.concat_genomes}"

rule vcontact3:
    # Run vContact3 on the concatenated representatives. It uses the pre-installed
    # conda environment named by vcontact_env and the database at vcontact_db_path.
    input:
        genomes_fna = "vcontact3_concat_genomes.fna",
        vcontact_prep_done = "vcontact3_prepared.done"
    output:
        done = touch("vcontact3.done"),
        vcontact_output_dir = directory("vcontact3_output")
    params:
        vcontact3_db = vcontact_db_path
    conda:
        vcontact_env
    threads: 64
    resources:
        mem_mb = 500 * 1024,
    log:
        "logs/vcontact3.log"
    shell:
        "vcontact3 run -n {input.genomes_fna} -d {params.vcontact3_db} " \
        "-t {threads} -o {output.vcontact_output_dir} --pyrodigal-gv &> {log}"

rule process_taxonomy:
    # TODO: write scripts/process_vcontact3_taxonomy.py, the script that processes
    # the vcontact3 taxonomy output into final_reconstructed_metadata.tsv.
    input:
        vcontact_output_dir = "vcontact3_output",
        vcontact_done = "vcontact3.done"
    params:
        ictv_metadata = "ictv/genome_metadata.tsv",
        metavr_metadata = "metavr/metavr_filtered_metadata.tsv"
    output:
        done = touch("vcontact3_taxonomy_processed.done"),
        taxonomy_tsv = "final_reconstructed_metadata.tsv"
    resources:
        mem_mb = 12 * 1024,
    log:
        "logs/process_vcontact3_taxonomy.log"
    # Snakemake resolves `script:` paths relative to the directory containing this
    # Snakefile, independent of --directory, so a plain relative path suffices.
    # "{workflow.basedir}" is not a supported placeholder in script paths.
    script:
        "scripts/process_vcontact3_taxonomy.py"

