#!/usr/bin/env python
import argparse
from datetime import datetime, timezone
from io import StringIO
import logging
import os
from subprocess import PIPE, Popen

import skyway
from skyway.cloud.aws import *
from skyway.cloud.gcp import *
from skyway.cloud.azure import *
from skyway.cloud.slurm import *
from skyway.cloud.oci import *

import colorama
from colorama import Fore

# export SKYWAYROOT=/project/rcc/trung/skyway-github
#./skyway_list --account=rcc-aws

class InstanceDescriptor:
    """Describe a Skyway job and the cloud/on-premise account it belongs to.

    The account wrapper is selected by substring match on ``vendor_name``
    ('aws', 'gcp', 'azure', 'oci' or 'midway3', checked in that order). When
    none of them matches, ``self.account`` stays ``None`` and the methods
    that use the account will fail.
    """

    def __init__(self, jobname: str, account_name: str, node_type: str, walltime: str, vendor_name: str):
        """Store the job attributes and create the matching account object.

        Args:
            jobname: Name of the job / instance.
            account_name: Account name passed to the vendor class.
            node_type: Node type, used to look up the unit price.
            walltime: Wall time formatted as ``HH:MM:SS``.
            vendor_name: Vendor identifier used to pick the account class.

        Raises:
            KeyError: If the ``USER`` environment variable is not set.
        """
        self.jobname = jobname
        self.account_name = account_name
        self.node_type = node_type
        self.walltime = walltime
        self.vendor_name = vendor_name

        self.account = None
        if 'aws' in vendor_name:
            self.account = AWS(account_name)
        elif 'gcp' in vendor_name:
            self.account = GCP(account_name)
        elif 'azure' in vendor_name:
            self.account = AZURE(account_name)
        elif 'oci' in vendor_name:
            self.account = OCI(account_name)
        elif 'midway3' in vendor_name:
            self.account = SLURMCluster(account_name)

        self.user = os.environ['USER']

    def getBalance(self):
        """Return the remaining balance reported by the account database for the current user."""
        # The call returns (accumulating_cost, remaining_balance); only the
        # second value is needed here.
        _, remaining_balance = self.account.get_cost_and_usage_from_db(user_name=self.user)
        return remaining_balance

    def getEstimateCost(self):
        """Return the estimated cost: whole hours of wall time times the unit price.

        The wall time is parsed as ``%H:%M:%S``; seconds are ignored and the
        hour count is truncated to an integer (e.g. ``02:30:00`` counts as 2).
        For 'midway3' vendors a fixed unit price of 1.0 is used.
        """
        pt = datetime.strptime(self.walltime, "%H:%M:%S")
        # NOTE(review): int() truncates partial hours; kept as in the original.
        walltime_in_hours = int(pt.hour + pt.minute/60)
        if "midway3" in self.vendor_name:
            unit_price = 1.0
        else:
            unit_price = float(self.account.get_unit_price(self.node_type))
        return walltime_in_hours * unit_price

    @staticmethod
    def _build_scp_command(private_key, remote, local_data, from_cloud=False, cloud_path=""):
        """Return the scp argument list for one transfer.

        Args:
            private_key: Path of the identity file passed to ``scp -i``.
            remote: Login string of the remote node (``user@host``).
            local_data: A local path or a sequence of local paths. When
                copying from the cloud only the first entry is used, as the
                local destination.
            from_cloud: Copy from the remote node when true, to it otherwise.
            cloud_path: Remote path. When uploading, an empty value means the
                remote home directory and a non-empty value is prefixed with
                ``/``; when downloading it is used as given.

        Raises:
            ValueError: If no local path is supplied.
        """
        # Accept a bare string as a single path instead of iterating its characters.
        if isinstance(local_data, str):
            paths = [local_data] if local_data else []
        else:
            paths = [str(path) for path in local_data]
        if not paths:
            raise ValueError("No local path given for the transfer.")

        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact.
        cmd = ["scp", "-rC", "-i", str(private_key),
               "-o", "StrictHostKeyChecking=accept-new"]
        if from_cloud:
            cmd += [f"{remote}:{cloud_path}", paths[0]]
        elif cloud_path == "":
            cmd += paths + [f"{remote}:~/"]
        else:
            cmd += paths + [f"{remote}:/{cloud_path}"]
        return cmd

    def transferData(self, node_names, local_data, from_cloud=False, cloud_path=""):
        '''
        Copy data between the local machine and the node running this job using scp.

        For 'aws', 'gcp' and 'oci' vendors the instance is looked up from the
        job name and the data is copied with scp. For 'midway3' only the node
        connection info is looked up for now (on-premise compute node with
        rcc-staff); nothing is copied. Other vendors are ignored.

        Args:
            node_names: Currently unused.
            local_data: Local path(s) to upload, or the local destination
                (first entry) when ``from_cloud`` is true.
            from_cloud: Copy from the cloud node when true.
            cloud_path: Remote path; empty means the remote home directory
                when uploading.

        Raises:
            ValueError: If no local path is supplied for a cloud transfer.
        '''
        if "midway3" in self.vendor_name:
            instanceID = self.account.get_host_ip(self.jobname)
            self.account.get_node_connection_info(instanceID)
            return

        if not any(vendor in self.vendor_name for vendor in ("aws", "gcp", "oci")):
            return

        instanceID = self.account.get_instance_ID(self.jobname)
        node_info = self.account.get_node_connection_info(instanceID)
        cmd = self._build_scp_command(
            node_info['private_key'],
            node_info['login'],
            local_data,
            from_cloud=from_cloud,
            cloud_path=cloud_path,
        )

        print(f"Executing: {' '.join(cmd)}")
        try:
            Popen(cmd).wait()
        except OSError as err:
            # e.g. scp is not installed; report it rather than crash, as
            # os.system() would have done.
            print(f"Failed to run scp: {err}")

if __name__ == "__main__":
    # Command-line entry point: parse the arguments, work out the vendor,
    # then copy the data to or from the node running the given job.

    msg = "Skyway transfer data to or from the VM of a cloud account"
    parser = argparse.ArgumentParser(description=msg)
    parser.add_argument('-J', '--job-name', dest='jobname', default="your-run", help="Job name")
    parser.add_argument('-A', '--account', dest='account', default="", help="Account name")
    parser.add_argument('--provider', dest='provider', default="", help="Vendor name: AWS, GCP, Azure, or RCC Midway")
    parser.add_argument('--from-cloud', dest='from_cloud', action='store_true', default=False, help="Copy data from cloud if specified")
    parser.add_argument('--cloud-path', dest='cloud_path', default="", help="Path to cloud space, empty for $HOME")
    # The default is a list so that "no data" is an empty list, not a string.
    parser.add_argument(dest='data', nargs='*', default=[], help="Data to transfer to the VM")

    args = parser.parse_args()

    job_name = args.jobname
    account_name = args.account
    provider = args.provider.lower()
    data = args.data
    from_cloud = args.from_cloud
    cloud_path = args.cloud_path

    if provider:
        vendor_name = provider
    else:
        # Try to infer the vendor name from the account name; the first
        # matching keyword wins.
        keyword_to_vendor = (
            ("aws", "aws"),
            ("gcp", "gcp"),
            ("azure", "azure"),
            ("oci", "oci"),
            ("midway3", "rcc-midway3"),
            ("rcc-staff", "rcc-midway3"),
        )
        vendor_name = next(
            (vendor for keyword, vendor in keyword_to_vendor if keyword in account_name),
            None,
        )
        if vendor_name is None:
            # Previously vendor_name was left unbound here and the script
            # crashed later with a NameError.
            parser.error(f"cannot infer the provider from account '{account_name}'; use --provider")

    if not data:
        parser.error("at least one local path is required")

    # SKYWAYROOT must be defined; .get() lets the intended error message be
    # raised when the variable is missing, not only when it is empty.
    skywayroot = os.environ.get('SKYWAYROOT', "")
    if skywayroot == "":
        raise Exception("SKYWAYROOT is not defined.")

    # create an instance descriptor (like with the dashboard)
    instanceDescriptor = InstanceDescriptor(job_name, account_name, "", "", vendor_name)

    # transfer
    node_names = [job_name]
    instanceDescriptor.transferData(node_names=node_names, local_data=data, from_cloud=from_cloud, cloud_path=cloud_path)
    