#!/usr/bin/env python3
import pyshark
import time
import asyncio
import logging
import sys
import json
import netifaces
# import psutil
from pathlib import Path
from pxgrid_pyshark_test import endpointsdb
from pxgrid_pyshark_test import parser
from pxgrid_util import WebSocketStomp
from pxgrid_util import Config
from pxgrid_util import create_override_url
from pxgrid_util import PXGridControl
from websockets.exceptions import WebSocketException
from signal import SIGINT, SIGTERM

logger = logging.getLogger(__name__)

## Wireshark display filter selecting only the protocols we can parse
# previous variant also matched bare xml:
# '!ipv6 && (ssdp || (http && http.user_agent != "") || sip || xml || (mdns && dns.resp.type == 16))'
default_filter = '!ipv6 && (ssdp || (http && http.user_agent != "") || sip || (mdns && dns.resp.type == 16))'

## Single shared parser instance; renamed from 'parser' so it no longer
## shadows the parser class imported from pxgrid_pyshark_test
_parser = parser()

## Create dict of supported protocols and their appropriate inspection functions
packet_callbacks = {
    'sip': _parser.parse_sip,
    'ssdp': _parser.parse_ssdp,
    'mdns': _parser.parse_mdns,
    'http': _parser.parse_http,
    'xml': _parser.parse_xml,
}

## Process network packets using global Parser instance and dictionary of supported protocols
def process_packet(packet):
    """Dispatch each recognized protocol layer of *packet* to its parser.

    Looks every layer name up in the module-level ``packet_callbacks`` table
    and feeds the parsed result into the global ``endpoints`` database.
    Any failure is logged at debug level and the packet is skipped.
    """
    try:
        for layer in packet.layers:
            callback = packet_callbacks.get(layer.layer_name.lower())
            if callback is None:
                continue
            endpoints.update_db_list(callback(packet))
    except Exception as e:
        logger.debug(f'exception occured: {e} with {packet.highest_layer} packet')

async def default_service_reregister_loop(config, pxgrid, service_id, reregister_delay):
    '''
    Keep the registered custom service alive by reregistering it periodically.

    Every *reregister_delay* seconds this calls ``service_reregister`` for
    *service_id*, then looks the service back up and logs the result so the
    operator can confirm registration is intact. Runs until cancelled.

    :param config: pxgrid_util Config (provides ``.service`` name)
    :param pxgrid: PXGridControl instance used for the pxGrid calls
    :param service_id: id returned by the original service_register() call
    :param reregister_delay: seconds to sleep between reregistrations
    '''
    try:
        while True:
            await asyncio.sleep(reregister_delay)
            try:
                resp = pxgrid.service_reregister(service_id)
                logger.debug(
                    '[default_service_reregister_loop] service reregister response %s',
                    json.dumps(resp))
            except Exception as e:
                logger.debug(
                    '[default_service_reregister_loop] failed to reregister, Exception: %s',
                    e.__str__())
                # reregister failed: nothing to verify this round, and 'resp'
                # would be unbound below — skip the lookup and retry later
                continue

            # pull the service back and log it, to confirm the registration took
            service_lookup_response = pxgrid.service_lookup(config.service)
            debug_text = json.dumps(service_lookup_response, indent=2, sort_keys=True)
            for debug_line in debug_text.splitlines():
                logger.debug('[default_service_reregister_loop] service_lookup_response %s', debug_line)

    except asyncio.CancelledError:
        logger.debug('[default_service_reregister_loop] reregister loop cancelled')

async def default_publish_loop(config, secret, pubsub_node_name, ws_url, topic):
    '''
    Publish pending endpoint records from the local DB to ISE via pxGrid pubsub.

    Opens a STOMP-over-websocket session to *pubsub_node_name* at *ws_url*,
    then every 5 seconds pulls the active entries from the module-level
    ``endpoints`` database and publishes one "UPDATE" asset message per row
    to *topic*, marking each row as sent. Runs until cancelled or until a
    websocket error occurs.

    :param config: pxgrid_util Config (node name, ssl context, overrides)
    :param secret: pxGrid access secret for this pubsub node
    :param pubsub_node_name: pubsub node to STOMP CONNECT to
    :param ws_url: websocket URL of the pubsub service
    :param topic: topic handle to publish asset updates on
    '''
    # Optionally rewrite the discovered URL (e.g. lab/DNS overrides)
    if config.discovery_override:
        logger.info('[default_publish_loop] overriding original URL %s', ws_url)
        ws_url = create_override_url(config, ws_url)
        logger.info('[default_publish_loop] new URL %s', ws_url)

    logger.debug('[default_publisher_loop] starting subscription to %s at %s', topic, ws_url)

    logger.debug('[default_publish_loop] opening web socket and stomp')
    ws = WebSocketStomp(
        ws_url,
        config.node_name,
        secret,
        config.ssl_context,
        # ping_interval=None)
        ping_interval=config.ws_ping_interval)

    try:
        logger.debug('[default_publish_loop] connect websocket')    
        await ws.connect()
        logger.debug('[default_publish_loop] connect STOMP node %s', pubsub_node_name)    
        await ws.stomp_connect(pubsub_node_name)
    except Exception as e:
        # Connect failure ends this loop silently; the caller sees no exception.
        logger.debug('[default_publish_loop] failed to connect, Exception: %s', e.__str__())
        return
    try:
        # count is cumulative over the life of the loop, not per iteration
        count = 0
        while True:
            await asyncio.sleep(5.0)
            logger.debug('obtaining endpoints from local db to send to ISE')
            results = await endpoints.get_active_entries()
            logger.debug(f'local db records pending update to ISE: {len(results)}')
            if results:
                for row in results:
                    # Index -> field mapping below follows the endpointsdb row
                    # layout — TODO confirm against endpointsdb.get_active_entries()
                    message = {
                        "opType": "UPDATE",
                        "asset": {
                            "assetId": row[3],
                            "assetName": row[4],
                            "assetIpAddress": row[2],
                            "assetMacAddress": row[0],
                            "assetVendor": row[5],
                            "assetHwRevision": row[6],
                            "assetSwRevision": row[7],
                            "assetProtocol": row[1],
                            "assetProductId": row[8],
                            "assetSerialNumber": row[9],
                            "assetDeviceType": row[10]
                        }
                    }
                    try:
                        await ws.stomp_send(topic, json.dumps(message))
                        logger.debug(f'ISE Endpoint Updated: {row[0]}, {row[2]}')
                        count += 1
                        # Mark the row as pushed so it is not re-sent next cycle
                        await endpoints.ise_endpoint_updated(row[0])
                    except Exception as e:
                        # A failed send leaves the row active; it will be retried
                        logger.debug(
                            '[default_publish_loop] Exception: %s',
                            e.__str__())
                logger.debug(f'endpoint updates sent to ISE: {str(count)}')
            logger.debug(
                '[default_publish_loop] message published to node %s, topic %s',
                pubsub_node_name,
                topic)
            sys.stdout.flush()
    except asyncio.CancelledError as e:
        # Normal shutdown path: task cancelled by the signal handler
        pass
    except WebSocketException as e:
        # Websocket died mid-publish; skip the graceful STOMP teardown below
        logger.debug(
            '[default_publish_loop] WebSocketException: %s',
            e.__str__())
        return
    
    # Graceful STOMP/websocket teardown after cancellation
    logger.debug('[default_publish_loop] shutting down publisher...')
    await ws.stomp_disconnect('123')
    await asyncio.sleep(2.0)
    await ws.disconnect()

## Process a given PCAP(NG) file with a provided PCAP filter
def process_capture_file(capture_file, capture_filter):
    '''Parse every packet of an on-disk capture file through process_packet().

    :param capture_file: path to a PCAP/PCAPNG file
    :param capture_filter: Wireshark display filter limiting which packets
        pyshark delivers
    '''
    if not Path(capture_file).exists():
        logger.debug(f'capture file not found: {capture_file}')
        return
    start_time = time.perf_counter()
    capture = pyshark.FileCapture(capture_file, display_filter=capture_filter, only_summaries=False, include_raw=True, use_json=True)
    currentPacket = 0
    try:
        for packet in capture:
            ## Wrap individual packet processing within 'try' statement to avoid formatting issues crashing entire process
            try:
                process_packet(packet)
            except TypeError as e:
                logger.debug(f'Error processing packet: {capture_file}, packet # {currentPacket}: TypeError: {e}')
            currentPacket += 1
    finally:
        ## Always release the tshark subprocess, even if iteration raises
        capture.close()
    end_time = time.perf_counter()
    logger.debug(f'processing capture file complete: execution time: {end_time - start_time:0.6f} : {currentPacket} packets processed ##')

def capture_live_packets(network_interface, filter):
    '''Sniff *network_interface* indefinitely, parsing each captured packet.

    Applies the given Wireshark display *filter* and hands every packet to
    process_packet(); only returns if the capture iterator stops.
    '''
    capture = pyshark.LiveCapture(interface=network_interface, only_summaries=False, include_raw=True, use_json=True, display_filter=filter)
    for currentPacket, packet in enumerate(capture.sniff_continuously()):
        process_packet(packet)

if __name__ == '__main__':
    ## Parse all of the CLI options provided
    config = Config()

    ## Add additional arguments to pxgrid_util Config class for pyshark funtionality
    g = config.parser.add_mutually_exclusive_group(required=True)
    g.add_argument(
        '--interface',
        help='Network interface receiving traffic to be analyzed')
    g = config.parser.add_mutually_exclusive_group(required=False)
    g.add_argument(
        '--filter', 
        help='Wireshark-formatted filter to be applied to received traffic',
        default=default_filter)

    config.parse_args()

    ## Verbose logging if configured
    if config.verbose:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s:%(name)s:%(levelname)s:%(message)s'))
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)

        # and set for stomp and ws_stomp modules and sub-CLASSES of pxgrid-pyshark
        for modname in ['pxgrid_util.stomp', 'pxgrid_util.ws_stomp', 'pxgrid_util.pxgrid', 'pxgrid_pyshark_test.parser', 'pxgrid_pyshark_test.endpointsdb', 'pxgrid_pyshark_test.ouidb']:
            s_logger = logging.getLogger(modname)
            handler.setFormatter(logging.Formatter('%(asctime)s:%(name)s:%(levelname)s:%(message)s'))
            s_logger.addHandler(handler)
            s_logger.setLevel(logging.DEBUG)

    ## Verify required attributes provided via CLI
    # NOTE(review): these error paths exit with status 0 — confirm intended
    if not config.hostname:
        print("No hostname!")
        sys.exit(0)
    if not config.node_name:
        print("No nodename provided (aka. pxgrid account username)")
        sys.exit(0)
    # Fall back to the standard endpoint-asset service/topic when not supplied.
    # NOTE(review): assignment targets config.config (the underlying parsed-args
    # namespace), presumably because config.service/.topic are read-only
    # proxies — verify against pxgrid_util's Config class.
    if not config.service:
        config.config.service = 'com.cisco.endpoint.asset'
        logger.debug(f'using default pxgrid service: com.cisco.endpoint.asset')
    if not config.topic:
        config.config.topic = 'asset'
        logger.debug(f'using default pxgrid topic: asset')
    if not config.config.interface:
        print("No capture interface provided")
        sys.exit(1)
    else:
        capture_int = config.config.interface
        ints = netifaces.interfaces()
        # Reject interface names the OS does not know about
        if capture_int not in ints:
            print(f'Invalid interface name provided: {capture_int}.')
            print(f'Valid interface names are: {ints}')
            sys.exit(1)
        # if check_interface(capture_int) is False:
        #     print(f'Invalid interface name provided: {capture_int}.')
        #     sys.exit(1)
        logger.debug(f'using capture interface = {capture_int}')
    
    capture_filter = config.config.filter
    logger.debug(f'using capture filter = {capture_filter}')
    # NOTE(review): parse_args() was already called above; this second call
    # looks redundant — confirm before removing.
    config.parse_args()    
    

    ## Setup pxGrid control object
    pxgrid = PXGridControl(config=config)
    ## Ensure account provided is approved in ISE UI
    # Blocks here, polling once a minute, until the account state is ENABLED
    while pxgrid.account_activate()['accountState'] != 'ENABLED':
        time.sleep(60)
    ## Register a custom service
    # The topic handle (e.g. 'asset') maps to the pubsub destination for the service
    properties = {
        'wsPubsubService': 'com.cisco.ise.pubsub',
        f'{config.topic}': f'/topic/{config.service}',
    }
    resp = pxgrid.service_register(config.service, properties)
    debug_text = json.dumps(resp, indent=2, sort_keys=True)
    for debug_line in debug_text.splitlines():
        logger.debug('[service_register_response] %s', debug_line)
    ## Setup periodic service reregistration as a task
    # Scheduled on the default loop; it only starts running once the loop runs
    # (run_until_complete near the end of this script)
    reregister_task = asyncio.ensure_future(
        default_service_reregister_loop(
            config,
            pxgrid,
            resp['id'],
            config.reregister_delay,
    ))

    ## Lookup service and topic details for the service we just registered
    service_lookup_response = pxgrid.service_lookup(config.service)
    slr_string = json.dumps(service_lookup_response, indent=2, sort_keys=True)
    logger.debug('service lookup response:')
    for s in slr_string.splitlines():
        logger.debug('  %s', s)
    service = service_lookup_response['services'][0]
    pubsub_service_name = service['properties']['wsPubsubService']
    try:
        topic = service['properties'][config.topic]
    except KeyError as e:
        # Help the operator by listing the topic handles that do exist
        logger.debug('invalid topic %s', config.topic)
        possible_topics = [
            k for k in service['properties'].keys()
            if k != 'wsPubsubService' and k != 'restBaseUrl' and k != 'restBaseURL'
        ]
        logger.debug('possible topic handles: %s', ', '.join(possible_topics))
        sys.exit(1)

    ## Lookup the pubsub service
    service_lookup_response = pxgrid.service_lookup(pubsub_service_name)

    ## Use the first pubsub service node returned (there is randomness)
    pubsub_service = service_lookup_response['services'][0]
    pubsub_node_name = pubsub_service['nodeName']
    secret = pxgrid.get_access_secret(pubsub_node_name)['secret']
    ws_url = pubsub_service['properties']['wsUrl']

    ## Setup the publishing loop
    main_task = asyncio.ensure_future(
        default_publish_loop(
            config,
            secret,
            pubsub_node_name,
            ws_url,
            topic,
    ))

    ## Setup sigint/sigterm handlers
    # Cancelling both tasks lets their CancelledError handlers run cleanup
    def signal_handlers():
        main_task.cancel()
        reregister_task.cancel()
    loop = asyncio.get_event_loop()
    loop.add_signal_handler(SIGINT, signal_handlers)
    loop.add_signal_handler(SIGTERM, signal_handlers)

    ## Create the local DB file for storing parsed packets
    logger.debug('building databases')
    endpoints = endpointsdb()
    logger.debug('building databases complete')

    # TODO(review): hardcoded developer capture files — parameterize or remove
    process_capture_file('/Users/aacook/Desktop/_LAB/paul-long.pcapng', capture_filter)
    process_capture_file('/Users/aacook/Desktop/_LAB/DuckNET.pcapng', capture_filter)
    process_capture_file('/Users/aacook/Desktop/_LAB/Canopy.pcapng', capture_filter)
    process_capture_file('/Users/aacook/Desktop/_LAB/eth1-20180320.pcap', capture_filter)
    process_capture_file('/Users/aacook/Desktop/_LAB/XboxOne.pcapng', capture_filter)

    ## Begin the capture on the indicated interface (replace w/ relevant interface name)
    # NOTE(review): capture_live_packets() loops indefinitely, so the asyncio
    # tasks scheduled above only start if/when it returns — confirm intended
    logger.debug('begin live capture')
    capture_live_packets(capture_int, capture_filter)

    try:
        loop.run_until_complete(main_task)
    # bare except: swallows everything (incl. KeyboardInterrupt) so the final
    # report below always prints
    except:
        pass
    logger.debug('### FINAL OUTPUT ###')

    ## Provide output of all entries within local DB and stats for update messages
    endpoints.view_all_entries()
    endpoints.view_stats()