#!python
# SPDX-FileCopyrightText: (C) 2022 Avnet Embedded GmbH
# SPDX-License-Identifier: GPL-3.0-only

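# Select stopped Azure VMs carrying the requested ROLE tag, start up to the
# requested number of them, and print one JSON line (name, status, runner name)
# per started VM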

import argparse
import json
import logging
import multiprocessing
import os
import requests
from functools import partial
from random import randint
from requests.adapters import HTTPAdapter, Retry
from time import sleep

# Power states that mark a VM as available to be started (see main)
AZURE_AVAIL_STATUS = [
    'PowerState/deallocated',
    'PowerState/deallocating',
    'PowerState/stopped',
]

# Request headers for the GitHub REST API; the token is read from GITHUB_PAT
headers = {
    "X-GitHub-Api-Version": "2022-11-28",
    "Authorization": "Bearer {}".format(os.getenv('GITHUB_PAT')),
    "Accept": "application/vnd.github+json",
}

# Let INFO-level messages through on the root logger
logging.root.setLevel(logging.INFO)

# Retry policy for the shared session: up to 5 retries with a backoff factor
# of 1, also retrying on throttling (429) and the listed 5xx responses
retries = Retry(
    total=5,
    backoff_factor=1,
    status_forcelist=[429, 500, 502, 503, 504],
)
# Shared HTTP session; the retry policy applies to every https:// request
session = requests.Session()
session.mount('https://', HTTPAdapter(max_retries=retries))


class AzureClient():
    """Minimal client for the Azure Compute REST API (virtual machines only).

    Every request goes through the module-level retrying `session`.
    """

    # api-version query parameter sent with every Compute request
    API_VERSION = "2021-03-01"

    def __init__(self, azure_token: str):
        """Prepare the headers sent with every Azure request.

        Args:
            azure_token (str): bearer token for management.azure.com
        """
        self.headers = {
            # The HTTP header is "Content-Type"; a "content_type" key would be
            # sent verbatim as an unrelated, non-standard header.
            "Content-Type": "application/json",
            "authorization": f"bearer {azure_token}",
        }

    @classmethod
    def _vm_url(cls, subscription_id: str, ressource_group_name: str, suffix: str = "") -> str:
        """Build a Microsoft.Compute/virtualMachines URL.

        Args:
            subscription_id (str): Azure subscription ID
            ressource_group_name (str): Azure resource group name
            suffix (str): path appended after "virtualMachines/", e.g.
                "<vm>/instanceView" or "<vm>/start"; empty for the VM list

        Returns:
            str: the full URL, including the api-version query parameter
        """
        return (
            f"https://management.azure.com/subscriptions/{subscription_id}"
            f"/resourceGroups/{ressource_group_name}"
            f"/providers/Microsoft.Compute/virtualMachines/{suffix}"
            f"?api-version={cls.API_VERSION}"
        )

    def _instance_statuses(self, subscription_id: str, ressource_group_name: str, vm_name: str) -> list[dict]:
        """Return the "statuses" entries of a VM's instanceView ([] if absent)."""
        return session.get(
            self._vm_url(subscription_id, ressource_group_name, f"{vm_name}/instanceView"),
            headers=self.headers,
        ).json().get('statuses', [])

    def get_vms(self, subscription_id: str, ressource_group_name: str) -> list[dict]:
        """Retrieve the VMs of a resource group, each annotated with its power state.

        Every returned VM dict gets a 'status' key holding the first
        instanceView code containing "PowerState", or None if there is none.

        Args:
            subscription_id (str): the subscription in which we are currently working
            ressource_group_name (str): the resource group in which we are currently working

        Returns:
            list[dict]: the VM objects as sent by Azure, plus 'status'

        Raises:
            SystemExit: with status 1 if the VM list cannot be retrieved
        """
        try:
            response = session.get(
                self._vm_url(subscription_id, ressource_group_name),
                headers=self.headers,
            )
            # Without this check, an auth/permission error body ({"error": ...})
            # would silently turn into an empty VM list.
            response.raise_for_status()
            virtual_machines = response.json().get('value', [])
        except Exception:
            logging.exception(
                "Got error when requesting virtual machines from Azure, does the token have all permissions to get them?")
            raise SystemExit(1)

        for vm in virtual_machines:
            statuses = self._instance_statuses(subscription_id, ressource_group_name, vm['name'])
            vm['status'] = next(
                (x['code'] for x in statuses if 'PowerState' in x.get('code', '')),
                None,
            )

        return virtual_machines

    def start_vm(self, subscription_id: str, ressource_group_name: str, vm: dict, start_mode: str,
                 max_attempts: int = 10) -> bool:
        """Send a power action to the selected VM and report whether it is running.

        On success, one JSON line (name, previous status, GH_RUNNER_NAME tag) is
        printed to stdout.

        Args:
            subscription_id (str): the subscription in which we are currently working
            ressource_group_name (str): the resource group in which we are currently working
            vm (dict): the VM object as sent by Azure
            start_mode (str): VM action appended to the URL, e.g. "start" or "restart"
            max_attempts (int): how many times the request is sent when it fails
                with a connection error, before giving up

        Returns:
            bool: whether the instanceView reports "PowerState/running" right
            after the request; False if the request could not be sent or the
            state could not be read
        """
        logging.info(f"Starting VM {vm['name']}")
        url = self._vm_url(subscription_id, ressource_group_name, f"{vm['name']}/{start_mode}")

        post_action = None
        for _ in range(max_attempts):
            try:
                post_action = session.post(url, headers=self.headers)
                break
            except requests.exceptions.ConnectionError:
                sleep(randint(1, 5))
        if post_action is None:
            # Bounded instead of retrying forever, so a network outage cannot
            # hang the job indefinitely.
            logging.error(f"Could not send {start_mode} request for VM {vm['name']} after {max_attempts} attempts")
            return False

        logging.info(
            f"Started VM {vm['name']} - {post_action.status_code} {post_action.text}",
        )

        # flush=True: this may run in a multiprocessing worker whose buffered
        # stdout would otherwise be lost if the worker is terminated.
        print(json.dumps({
            "name": vm['name'],
            "status": vm.get('status'),
            "gh_runner_name": vm.get('tags', {}).get('GH_RUNNER_NAME', ''),
        }), flush=True)

        try:
            statuses = self._instance_statuses(subscription_id, ressource_group_name, vm['name'])
        except (requests.exceptions.RequestException, ValueError):
            logging.exception(f"Could not read the power state of VM {vm['name']}")
            return False
        return any('PowerState/running' in x.get('code', '') for x in statuses)


def get_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: arguments to parse; None (the default) lets argparse read
            sys.argv[1:], exactly as before.

    Returns:
        argparse.Namespace: subscription_id, ressource_group_name, tag,
        number (int) and token ('' when -t/--token is not given).
    """
    parser = argparse.ArgumentParser(
        prog="start-vm",
        description="Start all VMs in Azure",
    )
    parser.add_argument('subscription_id', help="Azure Subscription ID")
    parser.add_argument('ressource_group_name',
                        help="Azure Ressource Group Name")
    parser.add_argument('tag', help="VM tag to filter")
    parser.add_argument('number', help="Number of VM to start", type=int)
    parser.add_argument(
        '-t', '--token',
        help="Azure Token, may also be passed through the environment variable AZ_BEARER",
        default='',
    )

    return parser.parse_args(argv)


def _list_gh_runners() -> list[dict]:
    """Return every self-hosted runner of the avnet-embedded GitHub org.

    Follows the API's `Link: rel="next"` pagination so runners beyond the
    first page are not missed.

    Raises:
        requests.exceptions.RequestException: on HTTP/transport errors
        ValueError: if a response body is not valid JSON
    """
    url = "https://api.github.com/orgs/avnet-embedded/actions/runners"
    params = {"per_page": 100}
    runners = []
    while url:
        response = session.get(url, headers=headers, params=params)
        response.raise_for_status()
        runners.extend(response.json().get('runners', []))
        url = response.links.get('next', {}).get('url')
        # The "next" URL already carries its own query string.
        params = None
    return runners


def check_gh_service(vm: dict, attempts: int = 3, delay: int = 60) -> bool:
    """Check GH API to get status of runners.

    Polls the org's self-hosted runners until one whose name contains the VM
    name reports "online".

    Args:
        vm (dict): the VM object as sent by Azure; only vm['name'] is used
        attempts (int): how many times the runner list is queried
        delay (int): seconds to wait between two queries

    Returns:
        bool: True as soon as a matching runner is online, False once every
        attempt has been used
    """
    for attempt in range(attempts):
        if attempt:
            # Wait between polls even when the runner is not listed yet: a
            # freshly booted VM may not have registered its runner.
            sleep(delay)
        try:
            runners = _list_gh_runners()
        except (requests.exceptions.RequestException, ValueError):
            logging.exception("Could not list GitHub runners")
            continue
        # NOTE(review): substring match, so VM "vm1" also matches runner
        # "vm10" — kept as in the original; confirm this is intended.
        if any(
            vm['name'] in runner.get('name', '') and runner.get('status') == "online"
            for runner in runners
        ):
            return True
    return False


def start_and_check(args: argparse.Namespace, client: AzureClient, vm: dict):
    """Each process will start VM and GH services.

    Starts the VM, gives it 120 s to boot, then polls its GitHub runner with
    check_gh_service. While the runner is not online, the VM is restarted and
    given another 120 s, for at most 4 polling rounds in total.

    Args:
        args (argparse.Namespace): parsed CLI arguments (subscription_id,
            ressource_group_name are used)
        client (AzureClient): client used to send the start/restart actions
        vm (dict): the VM object as sent by Azure

    Returns:
        bool: True once the VM's runner is online, False if it never came
        online (the original loop could never return False: `i == 4` was
        unreachable in `range(0, 4)`).
    """
    rounds = 4
    client.start_vm(
        args.subscription_id, args.ressource_group_name, vm, "start")
    sleep(120)
    for attempt in range(1, rounds + 1):
        # The runner being online is the readiness signal. start_vm's return
        # value is read right after the request is sent and may not report
        # "running" yet, so it is not used to gate the check.
        if check_gh_service(vm):
            return True
        if attempt < rounds:
            logging.warning(
                f"GitHub runner for VM {vm['name']} not online, restarting (round {attempt}/{rounds})")
            client.start_vm(
                args.subscription_id, args.ressource_group_name, vm, "restart")
            sleep(120)
    logging.error(f"GitHub runner for VM {vm['name']} never came online")
    return False


def main():
    """Main function.

    Lists the resource group's VMs, keeps those whose power state is in
    AZURE_AVAIL_STATUS and whose ROLE tag equals the requested tag, and runs
    start_and_check on at most `number` of them in parallel worker processes.

    Returns:
        int: 0 once the selected VMs have been processed

    Raises:
        SystemExit: with status 1 when no Azure token is available
    """
    args = get_args()

    # sanity check on builds requested
    args.number = max(1, args.number)

    azure_token = args.token or os.getenv('AZ_BEARER')
    if not azure_token:
        logging.error(
            'Error: an Azure token is required, you may pass it as an argument or through the AZ_BEARER environment variable',
        )
        raise SystemExit(1)

    client = AzureClient(azure_token)

    virtual_machines = client.get_vms(
        args.subscription_id, args.ressource_group_name)

    # A VM may lack the "tags" key (the ROLE filter below already allows for
    # that), so never index it directly.
    summary = [
        {
            "name": vm['name'],
            "status": vm['status'],
            "tags": vm.get('tags', {}),
        }
        for vm in virtual_machines
    ]
    logging.info(f"All vms status: {summary}")

    # Checking if we have a VM currently stopped and ready to be powered on, and start it
    available_vms = [
        vm for vm in virtual_machines
        if vm['status'] in AZURE_AVAIL_STATUS
        and args.tag == vm.get("tags", {}).get("ROLE", "")
    ][:args.number]

    logging.info(
        f'{args.number} builds requested, starting {[x["name"] for x in available_vms]}')

    results = []
    if available_vms:
        # One worker per selected VM. close() + join() lets the workers exit
        # normally (flushing their stdout); terminate() — which is what the
        # Pool context manager and the atexit finalizer do — skips that.
        pool = multiprocessing.Pool(processes=len(available_vms))
        try:
            results = pool.map(partial(start_and_check, args, client), available_vms)
        finally:
            pool.close()
            pool.join()

    # start_and_check returns one bool per VM; False marks a failed VM.
    # (The original counted the string 'False' in a list wrapping the results,
    # which was always 0.)
    failed_vm = results.count(False)
    logging.info(f"""{failed_vm} VM failed to start on {args.number} VM requested""")
    return 0


if __name__ == "__main__":
    exit(main())
