#!/usr/bin/env python3
import pyshark
import time
import logging
from pathlib import Path
from pxgrid_pyshark_test import endpointsdb
from pxgrid_pyshark_test import parser

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Built-in Wireshark display filter used when the user supplies none:
# skip IPv6 traffic and keep only the protocols that have a parser callback
# below (HTTP only when a User-Agent is present, mDNS only for TXT (type 16)
# answers).
default_filter = (
    '!ipv6 && (ssdp || (http && http.user_agent != "") || sip || xml'
    ' || (mdns && dns.resp.type == 16))'
)

# Single shared parser object. NOTE: this rebinds the imported name `parser`
# to the object returned by calling it, so the imported callable is no
# longer reachable under that name in this module.
parser = parser()

# Not read by the functions in this file; process_capture_file keeps its own
# local counter.
currentPacket = 0

# Dispatch table: pyshark layer name -> parser method handling that layer.
packet_callbacks = {
    proto: getattr(parser, f'parse_{proto}')
    for proto in ('sip', 'ssdp', 'mdns', 'http', 'xml')
}

## Process network packets using the global parser instance and the dictionary of supported protocols
def process_packet(packet, db=None, callbacks=None):
    """Dispatch a packet to the parser callback of every supported layer it contains.

    For each layer in ``packet.layers`` whose ``layer_name`` has an entry in
    ``callbacks``, that callback is called with the whole packet and its
    result is passed to ``db.update_db_list``. A packet containing several
    supported layers is therefore handed to each matching callback in turn.

    Args:
        packet: A pyshark packet (anything exposing ``.layers`` whose items
            have a ``.layer_name`` attribute).
        db: Endpoint database that receives parsed results. Defaults to the
            module-level ``endpoints`` object, which is only created when this
            file is run as a script -- pass it explicitly when importing.
        callbacks: Mapping of layer name -> parser callback. Defaults to the
            module-level ``packet_callbacks`` table.
    """
    # TODO: resolve issues with XML parsing (it matches the HTTP packet first).
    if db is None:
        db = endpoints
    if callbacks is None:
        callbacks = packet_callbacks
    for layer in packet.layers:
        fn = callbacks.get(layer.layer_name)
        if fn is not None:
            db.update_db_list(fn(packet))

## Process a given PCAP(NG) file with a provided Wireshark display filter
def process_capture_file(capture_file, capture_filter):
    """Run every packet in a capture file that matches a display filter through process_packet.

    A ``TypeError`` raised while handling an individual packet is logged and
    that packet is skipped, so one malformed packet does not abort the run.
    The capture is always closed, even if an unexpected exception escapes.

    Args:
        capture_file: Path to the pcap/pcapng file to read.
        capture_filter: Wireshark display filter, passed to pyshark as
            ``display_filter``.
    """
    if not Path(capture_file).is_file():
        # WARNING rather than DEBUG so the user still gets feedback when
        # debug logging has not been configured.
        logger.warning(f'capture file not found: {capture_file}')
        return

    logger.debug(f'processing capture file: {capture_file}')
    start_time = time.perf_counter()
    capture = pyshark.FileCapture(capture_file, display_filter=capture_filter, only_summaries=False, include_raw=True, use_json=True)
    packet_count = 0
    try:
        for packet in capture:
            ## Wrap individual packet processing within 'try' statement to avoid formatting issues crashing entire process
            try:
                process_packet(packet)
            except TypeError as e:
                logger.debug(f'Error processing packet: {capture_file}, packet # {packet_count}: TypeError: {e}')
            packet_count += 1
    finally:
        # Close the capture on every exit path so its resources are released
        # even when an exception other than TypeError propagates.
        capture.close()
    end_time = time.perf_counter()
    logger.debug(f'processing capture file complete: execution time: {end_time - start_time:0.6f} : {packet_count} packets processed ##')

# ## Process network traffic received on a given interface with the provided Wireshark display filter
# def capture_live_packets(network_interface, filter):
#     capture = pyshark.LiveCapture(interface=network_interface, only_summaries=False, include_raw=True, use_json=True, display_filter=filter)
#     currentPacket = 0
#     for packet in capture.sniff_continuously():
#         try:
#             process_packet(packet)
#         except TypeError as e:
#             logger.debug(f'Error processing packet. TypeError: {e}')
#         currentPacket += 1

if __name__ == '__main__':
    debugMode = True

    if debugMode:
        # One shared handler (and formatter) for this script's logger and for
        # the pxgrid_pyshark_test sub-module loggers.
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s:%(name)s:%(levelname)s:%(message)s'))
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)

        # Extend the logger function to the sub-CLASSES
        for modname in ['pxgrid_pyshark_test.parser', 'pxgrid_pyshark_test.endpointsdb', 'pxgrid_pyshark_test.ouidb']:
            s_logger = logging.getLogger(modname)
            s_logger.addHandler(handler)
            s_logger.setLevel(logging.DEBUG)

    ## Initialize the endpoints database. It is bound at module level on
    ## purpose: process_packet() looks up the global name `endpoints`.
    logger.debug('building databases')
    endpoints = endpointsdb()
    logger.debug('building databases complete')

    ### CAPTURE FILE ###
    print('#################################')
    # Strip surrounding whitespace so pasted paths work and a blank/space-only
    # filter answer falls back to the built-in filter instead of reaching tshark.
    filename = input('Input local pcap(ng) file to be parsed: ').strip()
    capture_filter = input('Input custom wireshark filter (leave blank to use built-in filter): ').strip()
    if not capture_filter:
        capture_filter = default_filter
    print('#################################')

    process_capture_file(filename, capture_filter)

    logger.debug('### FINAL OUTPUT ###')
    endpoints.view_all_entries()
    endpoints.view_stats()