#!/usr/bin/env python

# export SKYWAYROOT=/project/rcc/trung/skyway-github
#./skyway_batch job_script.sh

import argparse
from datetime import datetime, timezone
from io import StringIO
import os
import sys
from subprocess import PIPE, Popen

import colorama
from colorama import Fore

import skyway
from skyway.cloud.aws import *
from skyway.cloud.gcp import *
from skyway.cloud.azure import *
from skyway.cloud.oci import *
from skyway.cloud.slurm import *

class InstanceDescriptor:
    """Bundle one job's node request with the vendor account used to serve it.

    The constructor stores the job parameters and instantiates the account
    object matching ``vendor_name``; ``submitJob`` then requests the node,
    optionally runs a pre-execute command, and runs the job script on the node.
    """

    def __init__(self, jobname: str, account_name: str, node_type: str, walltime: str, vendor_name: str):
        """Store the job parameters and create the vendor account object.

        Args:
            jobname: job name; passed to the account as the node name.
            account_name: account identifier handed to the vendor class.
            node_type: node type passed to ``create_nodes``.
            walltime: walltime passed to ``create_nodes``.
            vendor_name: vendor identifier; matched by substring against
                'aws', 'gcp', 'azure', 'oci' and 'midway3', in that order.
        """
        self.jobname = jobname
        self.account_name = account_name
        self.node_type = node_type
        self.walltime = walltime
        self.vendor_name = vendor_name

        # Substring match, first hit wins (e.g. "rcc-midway3" selects SLURMCluster).
        # An unrecognised vendor leaves the account as None; submitJob rejects it.
        self.account = None
        if 'aws' in vendor_name:
            self.account = AWS(account_name)
        elif 'gcp' in vendor_name:
            self.account = GCP(account_name)
        elif 'azure' in vendor_name:
            self.account = AZURE(account_name)
        elif 'oci' in vendor_name:
            self.account = OCI(account_name)
        elif 'midway3' in vendor_name:
            self.account = SLURMCluster(account_name)

        # USER can be unset (e.g. inside containers); do not fail construction over it.
        self.user = os.environ.get('USER', '')

    def _pre_execute_command(self, pre_execute):
        """Return ``pre_execute`` with this job's account and job name appended."""
        # Use the instance's own values: module-level names such as
        # ``account_name`` only exist when the file is run as a script.
        return pre_execute + " --account=" + self.account_name + " -J" + self.jobname

    def submitJob(self, script_name=None, pre_execute=""):
        """Request a node, run the optional pre-execute command, then the script.

        Args:
            script_name: path of the job script to execute on the node, or
                None to skip script execution.
            pre_execute: shell command run locally (via ``os.system``) once the
                node request has returned, e.g. a data transfer; "" to skip.

        Returns:
            Whatever the account's ``create_nodes`` call returned.

        Raises:
            ValueError: if the vendor name did not match any known vendor.
        """
        if self.account is None:
            raise ValueError(f"unsupported vendor name: {self.vendor_name}")

        print(Fore.BLUE + f"Requesting node from {self.vendor_name} with account {self.account_name}")
        nodes = self.account.create_nodes(self.node_type, [self.jobname], need_confirmation=False, walltime=self.walltime)

        # execute pre-execute commands after nodes are available: e.g. data transfers
        if pre_execute != "":
            os.system(self._pre_execute_command(pre_execute))

        # execute the commands listed in the script on the compute node
        if script_name is not None:
            if "midway3" in self.vendor_name:
                # for on-premises like midway3 instanceID is the host ip (which happens to be the node name)
                instanceID = self.account.get_host_ip(self.jobname)
                self.account.execute_script(instanceID, script_name)

            elif any(vendor in self.vendor_name for vendor in ("aws", "oci", "gcp")):
                instanceID = self.account.get_instance_ID(self.jobname)
                self.account.execute_script(instanceID, script_name)

            # NOTE(review): no script execution is performed for 'azure' vendors;
            # confirm whether that is intentional.

        return nodes

'''
   parse the job script to get the account information, node type (constraint) and walltime
'''
def parse_script(filename):
    """Extract the job settings from a Slurm-style job script.

    Recognised directives are ``#SBATCH --job-name=``, ``--account=``,
    ``--constraint=`` and ``--time=`` (long form with ``=`` only; all
    whitespace inside a directive is ignored). The last non-comment line
    containing ``skyway_`` is recorded as the pre-execute command.

    Args:
        filename: path of the job script to read.

    Returns:
        A dict with the keys 'jobname' (default "my-run"), 'account',
        'constraint', 'walltime' and 'skyway_cmd' (each default "").
    """
    # Maps each recognised #SBATCH option to its key in the returned dict.
    option_to_key = {
        "--job-name": "jobname",
        "--account": "account",
        "--constraint": "constraint",
        "--time": "walltime",
    }
    settings = {
        'jobname': "my-run",
        'account': "",
        'constraint': "",
        'walltime': "",
        'skyway_cmd': "",
    }

    with open(filename, 'r') as f:
        for line in f:
            stripped = line.strip()
            if stripped.startswith("#SBATCH"):
                # Drop the marker and every whitespace character (spaces, tabs),
                # leaving "--option=value".
                directive = "".join(stripped[len("#SBATCH"):].split())
                # partition keeps values that themselves contain '='.
                option, sep, value = directive.partition("=")
                if sep and option in option_to_key:
                    settings[option_to_key[option]] = value
            elif stripped.startswith("#"):
                # Ordinary comments (and the shebang) must not be mistaken for
                # the skyway command just because they mention "skyway_".
                continue
            elif "skyway_" in line:
                settings['skyway_cmd'] = line.strip('\n')

    return settings

# Script entry point: parse the job script given on the command line, infer the
# vendor from the account name, then request a node and run the script on it.
if __name__ == "__main__":

    colorama.init(autoreset=True)

    msg = "Skyway CLI batch job submission"
    if len(sys.argv) < 2:
        # Fail with a usage message instead of an IndexError on sys.argv[1].
        print(f"{msg}\nusage: {os.path.basename(sys.argv[0])} job_script.sh", file=sys.stderr)
        sys.exit(1)

    script = sys.argv[1]
    args = parse_script(script)

    job_name = args['jobname']
    account_name = args['account']
    node_type = args['constraint']
    walltime = args['walltime']
    skyway_cmd = args['skyway_cmd']
    # Nothing sets a provider yet, so the vendor is always inferred below.
    provider = ""

    if provider == "":
        # try to infer the vendor name from account
        if 'aws' in account_name:
            vendor_name = "aws"
        elif 'gcp' in account_name:
            vendor_name = "gcp"
        elif 'azure' in account_name:
            vendor_name = "azure"
        elif 'oci' in account_name:
            vendor_name = "oci"
        elif 'midway3' in account_name or 'rcc-staff' in account_name:
            vendor_name = "rcc-midway3"
        else:
            # Stop here: continuing would fail later with vendor_name undefined.
            print(f"account name {account_name} is not valid", file=sys.stderr)
            sys.exit(1)
    else:
        vendor_name = provider

    # get the env variable SKYWAYROOT (treat unset the same as empty)
    skywayroot = os.environ.get('SKYWAYROOT', "")
    if skywayroot == "":
        raise Exception("SKYWAYROOT is not defined.")

    # create an instance descriptor
    instanceDescriptor = InstanceDescriptor(job_name, account_name, node_type, walltime, vendor_name)

    # submit job
    instanceDescriptor.submitJob(script_name=script, pre_execute=skyway_cmd)
