################################################################################
# VariantFoldRNA: A Snakemake Pipeline for RiboSNitch Prediction
#
# Author: Kobie Kirven
# Email: kjk6173@psu.edu
#
# Assmann and Bevilacqua Labs
# The Pennsylvania State University
################################################################################

# Import the python modules:
import os
import subprocess
import sys

# print(str(config['csv']))
# Fill in defaults for options the user left unset: the output and temporary
# directories, and the number of chunks (estimated from the input file size).
# Directory the workflow was launched from; used as the base for default paths.
current_directory = os.getcwd()

# Each directory option falls back to a folder inside the launch directory
# when the user left it at the 'none' placeholder.
for _dir_option, _default_folder in (
    ("out_dir", "variantfoldrna_output"),
    ("tmp_dir", "variantfoldrna_tmp"),
):
    if config[_dir_option] == 'none':
        config[_dir_option] = f"{current_directory}/{_default_folder}"

# When the user leaves `chunks` at -1, derive it from the size of the input
# file: one chunk per started 10,000,000 bytes.
if config['chunks'] == -1:
    # A supplied CSV takes precedence over the VCF, mirroring how the rule
    # files are selected further down. Previously, supplying both inputs
    # matched neither branch and left `chunks` at -1.
    if str(config['csv']) == 'none':
        chunk_source, chunk_source_label = config['vcf'], "VCF"
    else:
        chunk_source, chunk_source_label = config['csv'], "CSV"

    # get the size of the input file
    size = os.path.getsize(chunk_source)
    # estimate the number of chunks
    config['chunks'] = size // 10000000 + 1

    # Print in blue the number of chunks that will be used
    print(f"\033[94m\nThe number of chunks to be used is based on the size of your {chunk_source_label}: {config['chunks']}\033[0m")

# -- Import common rules -- #
# Parse the comma-separated tool list once so that every check below works on
# the same values.
selected_tools = config["ribosnitch_prediction_tool"].lower().split(",")
supported_tools = ["snpfold", "riprap", "rnasnp:p-value", "rnasnp:dist", "remurna"]
uses_rnasnp = "rnasnp:p-value" in selected_tools or "rnasnp:dist" in selected_tools

# Reject unknown tool names before any rule file is loaded.
# NOTE: every error print below uses flush=True because os._exit() terminates
# the process without flushing stdio buffers, which could drop the message.
for tool in selected_tools:
    if tool not in supported_tools:
        print(f"\033[91m\nError: The ribosnitch prediction tool {tool} is not recognized. Please choose from SNPfold, RipRap, remuRNA, rnasnp:p-value, and rnasnp:dist\033[0m\n", flush=True)
        os._exit(1)

# RNAsnp only supports a single prediction temperature of 37 degrees C, so a
# temperature range ("start@stop") or any other value is rejected. This holds
# for both the VCF and the CSV workflow.
if uses_rnasnp:
    if "@" in str(config['temperature']) or float(config['temperature']) != 37.0:
        print("\033[91m\nError: RNAsnp only predicts riboSNitches at 37˚C\033[0m", flush=True)
        # Exit the program
        os._exit(1)

if config['csv'] == "none":
    # VCF workflow
    include: "rules/vep.smk"
    include: "rules/vcf_header.smk"
    include: "rules/combine_predictions.smk"
    include: "rules/chunk.smk"

    # Check to make sure that, if the user is running RNAsnp, they use
    # a flanking length that is between 100 and 800 nt in intervals of 50.
    # The test is applied to twice the configured flank length.
    if uses_rnasnp:
        flanking = config['flank_len']
        if flanking*2 < 100 or flanking*2 > 800 or flanking*2 % 50 != 0:
            # Print in red that the flanking length is not within the acceptable range
            print("\033[91m\nError: The flanking length for RNAsnp must be between 100 and 800 nt in intervals of 50.\033[0m", flush=True)
            # Exit the program
            os._exit(1)

    # Load the rules for the appropriate prediction tool
    if "snpfold" in selected_tools:
        include: "rules/snpfold.smk"

    if "riprap" in selected_tools:
        include: "rules/riprap.smk"

    # Both RNAsnp modes share one rule file
    if uses_rnasnp:
        include: "rules/rnasnp.smk"

    if "remurna" in selected_tools:
        include: "rules/remurna.smk"

else:
    # CSV workflow
    include: "rules/combine_predictions_csv.smk"
    include: "rules/chunk.smk"

    # Load the rules for the appropriate prediction tool
    if "snpfold" in selected_tools:
        include: "rules/csv_rules_snpfold.smk"

    if "riprap" in selected_tools:
        include: "rules/csv_rules_riprap.smk"

    # Both RNAsnp modes share one rule file
    if uses_rnasnp:
        include: "rules/csv_rules_rnasnp.smk"

    if "remurna" in selected_tools:
        include: "rules/csv_rules_remurna.smk"
    
########################################################
# Choose the sequence-extraction rules: the spliced cDNA
# variant only when the user explicitly asked for it,
# the regular variant otherwise
########################################################
if config["spliced"] != True:
    include: "rules/extract_seqs.smk"
else:
    include: "rules/extract_seqs_spliced.smk"


###########################################################
# If the user would like to perform riboSNitch prediction
# over a range of temperatures ("start@stop", advanced by
# the `temp_step` option), expand that range into the list
# of individual temperatures. Otherwise the list holds the
# single configured temperature.
###########################################################
if "@" in str(config["temperature"]):
    start_stop = str(config["temperature"]).split("@")
    start_temp = float(start_stop[0])
    stop_temp = float(start_stop[1])
    step_size = float(config["temp_step"])

    # flush=True because os._exit() terminates without flushing stdio buffers
    if stop_temp < start_temp:
        print("\033[91m\nError: The ending temperature is less than the starting temperature\033[0m\n", flush=True)
        os._exit(1)
    # A zero step would divide by zero below and a negative one would walk
    # away from the ending temperature
    if step_size <= 0:
        print("\033[91m\nError: The temperature step must be greater than 0\033[0m\n", flush=True)
        os._exit(1)

    temperature_range = [start_temp]

    # Upper bound on the number of temperatures in the range; it keeps the
    # loop finite even if rounding prevents the value from advancing
    steps = round((stop_temp - start_temp) / step_size) + 1

    for _ in range(steps):
        next_temp = temperature_range[-1] + step_size
        # Stop once the next value would pass the ending temperature. Compare
        # at 3 decimals so floating-point noise does not drop the end point.
        if round(next_temp, 3) > round(stop_temp, 3):
            break
        temperature_range.append(round(next_temp, 4))
else:
    temperature_range = [config["temperature"]]

#################################################
# Tell the user (in blue) which temperatures the
# predictions will be made at
#################################################
temperature_listing = ",".join(map(str, temperature_range))
print("\033[94mWill predict riboSNitches at the following temperatures:", temperature_listing + "\033[0m")


#############################################
# Collect the files the pipeline must produce:
# one combined prediction file per temperature,
# in a folder that depends on the input type
#############################################
if config['csv'] != "none":
    # CSV input
    final_input = [
        f"{config['out_dir']}/ribosnitch_predictions_csv/combined_ribosnitch_prediction_{temp}.txt"
        for temp in temperature_range
    ]
else:
    # VCF input: just the riboSNitch predictions, without the null predictions
    final_input = [
        f"{config['out_dir']}/ribosnitch_predictions/combined_ribosnitch_prediction_{temp}.txt"
        for temp in temperature_range
    ]


#######################
#  --  Pipeline  --   #
#######################
# Target rule: it produces nothing itself and only lists every path collected
# in `final_input` (one combined prediction file per temperature) as input.
# NOTE(review): presumably this acts as the workflow's default target, so that
# a plain run builds all of these files via the included rules; confirm.
rule all:
    input:
        final_input,
