#!python
#
# Copyright (c) 2021 Cisco Systems, Inc. and/or its affiliates
#
from pxgrid_util import PXGridControl
from pxgrid_util import Config
from pxgrid_util import create_override_url
from pxgrid_util import query
import time
import logging
import json
import sys
import asyncio
import aiohttp

# module-level logger; handlers are only attached in the `__main__` section
# below when verbose output is configured
logger = logging.getLogger(__name__)

# default number of worker coroutines used for bulk policy application
# (default for `coroutines` in apply_anc_policy_to_macs)
NUM_COROUTINES = 20

# shared queue of MAC addresses to apply policy to; filled by
# apply_anc_policy_to_macs and drained by the anc_applier workers
# NOTE(review): created at import time, outside any running event loop --
# presumably fine on current Python versions, but confirm for the oldest
# interpreter this script must support
mac_queue = asyncio.Queue()

async def anc_applier(
            q=None,
            config=None,
            secret=None,
            url=None,
            policy=None):
    '''
    Worker coroutine: drain MAC addresses from `q` and POST one ANC
    "apply policy" request per address.

    Parameters:
        q:      asyncio queue holding MAC address strings
        config: object providing `node_name` (HTTP basic auth user) and
                `ssl_context` (passed to the aiohttp connector)
        secret: password used for HTTP basic auth
        url:    URL each request is POSTed to
        policy: policy name sent as `policyName` in every request body

    Returns None once the queue is observed empty.

    Raises AssertionError if any argument is left as None. Exceptions
    raised by aiohttp while sending a request are not caught here.
    '''
    assert q is not None
    assert config is not None
    assert secret is not None
    assert url is not None
    assert policy is not None

    # identical for every request, so build it once
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }

    # re-check after the session closes in case more work was queued
    # while this worker was shutting its session down
    while not q.empty():

        # create a session
        auth = aiohttp.BasicAuth(config.node_name, secret)
        conn = aiohttp.TCPConnector(ssl=config.ssl_context)
        async with aiohttp.ClientSession(auth=auth, connector=conn) as session:

            # pull from the queue until empty
            while not q.empty():
                m = await q.get()
                payload = json.dumps({
                    'macAddress': m,
                    'policyName': policy,
                })

                # applies the ANC policy; the context manager releases the
                # response so its connection is returned to the pool
                # instead of lingering until garbage collection
                async with session.post(
                    url=url,
                    headers=headers,
                    data=payload,
                ) as response:
                    # tell the world how cool we are
                    logger.info('Applied policy %s to %s, response = %d', policy, m, response.status)


async def apply_anc_policy_to_macs(
    config=None,
    secret=None,
    url=None,
    policy=None,
    mac_address_file=None,
    coroutines=NUM_COROUTINES):
    '''
    Read MAC addresses from a file, enqueue them on `mac_queue` and run
    `coroutines` concurrent `anc_applier` workers until all are done.

    Parameters:
        config, secret, url, policy: passed through unchanged to every
            `anc_applier` worker
        mac_address_file: path of a text file with one MAC address per
            line; surrounding whitespace is stripped and blank lines are
            skipped
        coroutines: number of worker tasks to start

    Raises AssertionError if any of the required arguments is None, and
    propagates errors from opening the file or from the workers.
    '''
    assert config is not None
    assert secret is not None
    assert url is not None
    assert policy is not None
    assert mac_address_file is not None

    # what're we doing?
    logger.info(
        'Applying policy %s, reading MAC addresses from %s, using %d coroutines',
        policy,
        mac_address_file,
        coroutines)

    # enqueue all the MAC addresses
    with open(mac_address_file, 'r') as macs:
        for line in macs:
            # strip() rather than chopping the last character: the final
            # line may lack a newline, and CRLF files leave a stray '\r'
            mac = line.strip()
            if mac:
                mac_queue.put_nowait(mac)

    # start all the coroutines
    anc_tasks = [
        asyncio.create_task(anc_applier(
            q=mac_queue,
            config=config,
            secret=secret,
            url=url,
            policy=policy))
        for _ in range(coroutines)
    ]

    # wait for the workers to be done
    await asyncio.gather(*anc_tasks)


def apply_bulk_anc_policy_by_mac(config, secret, url, bulk_policy, bulk_mac_addrs_file):
    '''
    Synchronous entry point for bulk ANC policy application.

    Builds the `apply_anc_policy_to_macs` coroutine with NUM_COROUTINES
    workers and drives it to completion with `asyncio.run`.
    '''
    worker_settings = {
        'config': config,
        'secret': secret,
        'url': url,
        'policy': bulk_policy,
        'mac_address_file': bulk_mac_addrs_file,
        'coroutines': NUM_COROUTINES,
    }
    bulk_job = apply_anc_policy_to_macs(**worker_settings)
    asyncio.run(bulk_job)


if __name__ == '__main__':
    # Command-line driver: activates the pxGrid account, discovers the ANC
    # configuration service, builds the REST request selected by the
    # configured options and either sends it once or runs the bulk applier.
    config = Config()

    #
    # verbose logging if configured
    #
    if config.verbose:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s:%(name)s:%(levelname)s:%(message)s'))
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)

        # and set for stomp and ws_stomp modules also; they share the
        # handler (and therefore its formatter) configured above
        for modname in ['pxgrid_util.stomp', 'pxgrid_util.ws_stomp', 'pxgrid_util.pxgrid']:
            s_logger = logging.getLogger(modname)
            s_logger.addHandler(handler)
            s_logger.setLevel(logging.DEBUG)

    pxgrid = PXGridControl(config=config)

    # poll once a minute until the account is reported as enabled
    while pxgrid.account_activate()['accountState'] != 'ENABLED':
        time.sleep(60)

    # lookup for the ANC configuration service; the first entry is used
    service_lookup_response = pxgrid.service_lookup('com.cisco.ise.config.anc')
    service = service_lookup_response['services'][0]
    node_name = service['nodeName']
    base_url = service['properties']['restBaseUrl']

    secret = pxgrid.get_access_secret(node_name)['secret']
    # never write the secret itself to the log -- it is a credential
    logger.info('Retrieved access secret for node %s', node_name)
    payload = {}
    bulk_policy = None
    bulk_mac_addrs_file = None

    # the first configured option wins; each branch selects the REST
    # endpoint and fills in the request payload that endpoint expects
    if config.get_anc_endpoints:
        url = base_url + '/getEndpoints'

    elif config.get_anc_policies:
        url = base_url + '/getPolicies'

    elif config.apply_anc_policy:
        assert config.anc_policy
        assert config.anc_mac_address
        url = base_url + '/applyEndpointPolicy'
        payload['macAddress'] = config.anc_mac_address
        payload['policyName'] = config.anc_policy

    elif config.apply_anc_policy_by_mac:
        assert config.anc_policy
        assert config.anc_mac_address
        url = base_url + '/applyEndpointByMacAddress'
        payload['macAddress'] = config.anc_mac_address
        payload['policyName'] = config.anc_policy

    elif config.apply_anc_policy_by_mac_bulk:
        # bulk mode: the option value is the path of the MAC address file
        assert config.anc_policy
        url = base_url + '/applyEndpointByMacAddress'
        bulk_policy = config.anc_policy
        bulk_mac_addrs_file = config.apply_anc_policy_by_mac_bulk

    elif config.apply_anc_policy_by_ip:
        assert config.anc_policy
        assert config.anc_ip_address
        url = base_url + '/applyEndpointByIpAddress'
        payload['ipAddress'] = config.anc_ip_address
        payload['policyName'] = config.anc_policy

    elif config.clear_anc_policy_by_mac:
        assert config.anc_mac_address
        url = base_url + '/clearEndpointByMacAddress'
        payload['macAddress'] = config.anc_mac_address

    elif config.clear_anc_policy_by_ip:
        assert config.anc_ip_address
        url = base_url + '/clearEndpointByIpAddress'
        payload['ipAddress'] = config.anc_ip_address

    elif config.create_anc_policy:
        assert config.anc_policy_action
        url = base_url + '/createPolicy'
        payload['name'] = config.create_anc_policy
        payload['actions'] = [str(config.anc_policy_action)]

    elif config.delete_anc_policy:
        url = base_url + '/deletePolicyByName'
        payload['name'] = config.delete_anc_policy

    elif config.get_anc_policy_by_mac:
        assert config.anc_mac_address
        url = base_url + '/getEndpointByMacAddress'
        payload['macAddress'] = config.anc_mac_address

    else:
        # error level so the reason for the non-zero exit is visible even
        # when verbose logging has not been enabled
        logger.error('no valid options for getting, applying or removing ANC policy')
        sys.exit(1)

    # log url to see what we get via discovery
    logger.info('Using URL %s', url)

    # check to see if we need to override the URL
    if config.discovery_override:
        url = create_override_url(config, url)

    # make the request!!
    if not bulk_policy:
        payload = json.dumps(payload)
        logger.info('payload = %s', payload)
        resp = query(config, secret, url, payload)
        if resp:
            print(json.dumps(json.loads(resp), indent=2, sort_keys=True))
        else:
            # empty response body: still emit valid JSON
            print('{}')
    else:
        apply_bulk_anc_policy_by_mac(config, secret, url, bulk_policy, bulk_mac_addrs_file)
