#!/usr/bin/env python
# export SKYWAYROOT=/project/rcc/trung/skyway-github
# ./skyway_interactive.py --account=rcc-aws --constraint=t1 --time=01:00:00

import argparse
from datetime import datetime, timezone
from io import StringIO
import os
import sys
import subprocess
from subprocess import PIPE, Popen

import colorama
from colorama import Fore

import skyway
from skyway.cloud.aws import *
from skyway.cloud.gcp import *
from skyway.cloud.azure import *
from skyway.cloud.slurm import *
from skyway.cloud.oci import *

class InstanceDescriptor:
    """Describe one interactive job and drive it through a vendor account.

    The descriptor stores the job parameters and builds the account object
    matching ``vendor_name`` (AWS, GCP, AZURE, OCI or SLURMCluster).  All
    node operations are delegated to that account object.

    Attributes:
        jobname: Name of the job; also used as the node name.
        account_name: Name of the account handed to the vendor class.
        node_type: Node type requested from the vendor.
        walltime: Requested walltime as an ``HH:MM:SS`` string.
        vendor_name: Vendor identifier; matched by substring.
        account: Vendor account object, or ``None`` when ``vendor_name``
            matches none of the known vendors.
        user: Value of the ``USER`` environment variable.
    """

    def __init__(self, jobname: str, account_name: str, node_type: str, walltime: str, vendor_name: str):
        """Store the job parameters and instantiate the vendor account.

        Raises:
            KeyError: If the ``USER`` environment variable is not set.
        """
        self.jobname = jobname
        self.account_name = account_name
        self.node_type = node_type
        self.walltime = walltime
        self.vendor_name = vendor_name

        # Substring matching, so a name such as "rcc-midway3" selects the
        # SLURM cluster.  An unknown vendor leaves the account as None.
        self.account = None
        if 'aws' in vendor_name:
            self.account = AWS(account_name)
        elif 'gcp' in vendor_name:
            self.account = GCP(account_name)
        elif 'azure' in vendor_name:
            self.account = AZURE(account_name)
        elif 'oci' in vendor_name:
            self.account = OCI(account_name)
        elif 'midway3' in vendor_name:
            self.account = SLURMCluster(account_name)

        self.user = os.environ['USER']

    def submitJob(self):
        """Request the node(s) for this job and connect to them.

        Returns:
            Whatever ``account.create_nodes`` returns; when that value is
            truthy, ``connectJob`` is invoked before returning.
        """
        print(Fore.BLUE + f"Requesting instances (nodes) from {self.vendor_name} with account {self.account_name}")

        nodes = self.account.create_nodes(self.node_type,
                                          [self.jobname],
                                          need_confirmation=True,
                                          walltime=self.walltime,
                                          interactive=True)

        # Connect to the node named after the job.
        #    21Aug2024: create_nodes with interactive==False would skip over
        #               the connectJob step below, cause unknown
        #    27Sep2024: connectJob back to normal
        # Use this instance rather than module-level globals so the method
        # also works when the class is used outside the script entry point.
        if nodes:
            self.connectJob(node_names=[self.jobname])

        return nodes

    def connectJob(self, node_names):
        """Open a connection to the node that runs this job.

        The node is looked up by ``self.jobname``: by host IP for midway3,
        by instance ID for aws, gcp and oci.  Any other vendor (azure
        included) is silently ignored.

        Args:
            node_names: Accepted for interface compatibility; currently
                unused.
        """
        if "midway3" in self.vendor_name:
            instanceID = self.account.get_host_ip(self.jobname)
        elif any(vendor in self.vendor_name for vendor in ("aws", "gcp", "oci")):
            instanceID = self.account.get_instance_ID(self.jobname)
        else:
            return

        self.account.connect_node(instanceID)

    def terminateJob(self, node_names=None):
        """Destroy the node(s) of this job without asking for confirmation.

        Args:
            node_names: Names of the nodes to destroy.  Ignored for midway3,
                where the instance ID is looked up from ``self.jobname``.
                Defaults to an empty list.
        """
        if "midway3" in self.vendor_name:
            instanceID = self.account.get_instance_ID(self.jobname)
            self.account.destroy_nodes(IDs=[instanceID], need_confirmation=False)
        else:
            # A None default avoids sharing one list object across calls.
            names = [] if node_names is None else node_names
            self.account.destroy_nodes(node_names=names, need_confirmation=False)

    def getBalance(self):
        """Return the remaining balance of ``self.user`` from the database."""
        _, remaining_balance = self.account.get_cost_and_usage_from_db(user_name=self.user)
        return remaining_balance

    def getEstimateCost(self):
        """Estimate the cost of the job as walltime (hours) times unit price.

        Returns:
            The estimated cost as a float.

        Raises:
            ValueError: If ``self.walltime`` is not in ``HH:MM:SS`` format.
        """
        pt = datetime.strptime(self.walltime, "%H:%M:%S")
        # Keep the fractional part: truncating to whole hours would estimate
        # a 30-minute job at zero cost.
        walltime_in_hours = pt.hour + pt.minute / 60 + pt.second / 3600
        if "midway3" in self.vendor_name:
            # Fixed price; the account is not queried for midway3.
            unit_price = 1.0
        else:
            unit_price = float(self.account.get_unit_price(self.node_type))
        return walltime_in_hours * unit_price

    def list_nodes(self):
        """Return the first element of ``account.list_nodes(verbose=False)``."""
        nodes, _ = self.account.list_nodes(verbose=False)
        return nodes

if __name__ == "__main__":
    # Script entry point: parse the command line, work out the vendor,
    # then request an interactive node and connect to it.

    colorama.init(autoreset=True)

    msg = "Skyway CLI interactive mode"
    parser = argparse.ArgumentParser(description=msg)
    parser.add_argument('-J', '--job-name', dest='jobname', default="your-run", help="Job name")
    parser.add_argument('-A', '--account', dest='account', default="", help="Account name")
    parser.add_argument('--provider', dest='provider', default="", help="Vendor name: AWS, GCP, Azure, or RCC Midway")
    parser.add_argument('--partition', dest='partition', default="", help="Partition")
    parser.add_argument('--constraint', dest='constraint', default="", help="Node type")
    parser.add_argument('-t', '--time', dest='walltime', default="", help="Walltime")

    args = parser.parse_args()

    job_name = args.jobname
    account_name = args.account
    node_type = args.constraint
    walltime = args.walltime
    provider = args.provider.lower()

    if provider:
        # An explicit --provider always wins.
        vendor_name = provider
    # Otherwise try to infer the vendor name from the account name.
    elif 'aws' in account_name:
        vendor_name = "aws"
    elif 'gcp' in account_name:
        vendor_name = "gcp"
    elif 'azure' in account_name:
        vendor_name = "azure"
    elif 'oci' in account_name:
        vendor_name = "oci"
    elif 'midway3' in account_name or 'rcc-staff' in account_name:
        vendor_name = "rcc-midway3"
    else:
        # Without this branch vendor_name would be unbound below; exit with
        # the usage text and a clear message instead.
        parser.error(f"cannot infer the provider from account '{account_name}'; "
                     "please specify --provider")

    # SKYWAYROOT must be set and non-empty; .get() lets an unset variable
    # reach the explicit check instead of failing with a bare KeyError.
    skywayroot = os.environ.get('SKYWAYROOT', '')
    if not skywayroot:
        raise Exception("SKYWAYROOT is not defined.")

    # create an instance descriptor (like with the dashboard)
    instanceDescriptor = InstanceDescriptor(job_name, account_name, node_type, walltime, vendor_name)

    # submit job
    nodes = instanceDescriptor.submitJob()

    