#!/usr/bin/env python3
import select
import sys
import textwrap
import argparse
import logging
from os import makedirs
from os.path import expanduser, exists, join
from shutil import copyfile

from privex.helpers import ErrHelpParser, empty, empty_if
from privex.pyrewall import conf, VERSION
from privex.pyrewall.conf import FILE_SUFFIX, CONF_DIRS, SEARCH_DIRS, SERVICE_FILE, SERVICE_FILE_DEST
from privex.pyrewall.core import find_file, load_rules, save_rules, search_files, is_root, run_prog, run_prog_ex
from privex.pyrewall.PyreParser import PyreParser
from privex.pyrewall.exceptions import ReturnCodeError
from privex.pyrewall.repl import repl_main
from typing import Union, Tuple, Dict, List
from io import TextIOWrapper
from datetime import datetime

# NOTE(review): the logger name ends in ".repl" although this module is the CLI entry point - confirm intended.
log = logging.getLogger('privex.pyrewall.repl')


# One-line description of each sub-command. Used for the sub-parser descriptions and for the epilog below.
CMD_DESC = {
    'parse': f'Parse a {FILE_SUFFIX} file and output rules compatible with iptables-restore',
    'load': f'(Re-)load a Pyrewall {FILE_SUFFIX} file with iptables-restore',
    'repl': f'Open the Pyrewall REPL, optionally reading {FILE_SUFFIX} file(s) into it first',
    'install_service': f"(RUN AS ROOT) Install, enable, and start the systemd service from {SERVICE_FILE} into {SERVICE_FILE_DEST}",
}

# Bullet lists of the configured directories, embedded in the help epilog.
CONF_DIR_LIST = "\n".join(f"   - {c}" for c in CONF_DIRS)
SEARCH_DIR_LIST = "\n".join(f"   - {c}" for c in SEARCH_DIRS)

# Epilog shown underneath the argparse-generated help. It lists every sub-command registered further down
# in this file, followed by the directories taken from the configuration.
HELP_TEXT = textwrap.dedent(f'''\

PyreWall Version v{VERSION}
(C) 2020 Privex Inc. ( https://www.privex.io )
Official Repo: https://github.com/Privex/pyrewall


Sub-commands:

    parse  (-i 4|6) [filename]       - {CMD_DESC['parse']}
    load   (-i 4|6) (-n) (filename)  - {CMD_DESC['load']}
    repl   (filename ...)            - {CMD_DESC['repl']}
    install_service                  - {CMD_DESC['install_service']}

CONF_DIRS: 
{CONF_DIR_LIST}

SEARCH_DIRS: 
{SEARCH_DIR_LIST}

''')

# RawDescriptionHelpFormatter keeps the hand-formatted epilog above exactly as written.
parser = ErrHelpParser(
    description='PyreWall - Python firewall management using iptables',
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog=HELP_TEXT
)


def parse_stdin(ipver='both'):
    """
    Read Pyrewall rules from standard input, parse them, and print the resulting rules.

    Every line read from ``sys.stdin`` is stripped of surrounding whitespace before being handed to
    ``PyreParser.parse_lines``. The resulting IPv4 / IPv6 rules are printed via :func:`print_rules`.

    :param ipver: Passed through to :func:`print_rules` to select which IP version(s) get printed.
    """
    stripped = [raw.strip() for raw in sys.stdin]
    v4_rules, v6_rules = PyreParser().parse_lines(lines=stripped)
    print_rules(ip4=v4_rules, ip6=v6_rules, ipver=ipver)


def err(*msgs: str, file=sys.stderr, **kwargs):
    """
    ``print()`` wrapper which writes to standard error by default.

    :param msgs:   Values to print (handled exactly as ``print`` handles its positional arguments).
    :param file:   Stream to write to. Defaults to the ``sys.stderr`` object bound when this function was defined.
    :param kwargs: Any additional keyword arguments accepted by ``print`` (e.g. ``sep``, ``end``, ``flush``).
    """
    print(*msgs, file=file, **kwargs)


def timeout_input(prompt: str = "Continue? (y/N)", timeout: int = 3, default: str = "n") -> Tuple[int, str]:
    """
    Prompt on stdout and wait up to ``timeout`` seconds for a line to become readable on ``sys.stdin``.

    :param prompt:  Text printed (followed by ``': '``) before waiting.
    :param timeout: Seconds passed to ``select.select`` as the maximum time to wait.
    :param default: Answer returned when nothing became readable in time.
    :return: ``(0, <stripped line read from stdin>)`` if stdin became readable,
             otherwise ``(-1, default)``.
    """
    print(prompt, end=': ', flush=True)
    ready, _, _ = select.select([sys.stdin], [], [], timeout)
    # Terminate the prompt line regardless of whether an answer arrived.
    print()
    if not ready:
        return -1, default
    return 0, sys.stdin.readline().strip()


class RuleOutput:
    """
    Parse Pyrewall rules from a file or from stdin, then either write the generated rules out
    (:meth:`.parse`) or load them into iptables (:meth:`.load`).

    Instances keep references to the streams they open. They can be used as a context manager so those
    streams are closed deterministically::

        with RuleOutput(opt) as k:
            k.parse()

    The process-wide ``sys.stdin`` / ``sys.stdout`` / ``sys.stderr`` streams are never closed by this class.
    """
    VER_TUPLE = Tuple[List[str], List[str]]

    # Names of the attributes which may hold an open stream; used by _cleanup().
    _STREAM_ATTRS = ('input_stream', 'output_stream', 'output_stream4', 'output_stream6')

    def __init__(self, opt: argparse.Namespace):
        """
        :param opt: Parsed CLI arguments. Must provide ``ipver`` and ``file``; ``output``, ``output4`` and
                    ``output6`` are used when present (they are only defined for the ``parse`` sub-command).
        """
        super().__init__()
        self.ip_ver = opt.ipver
        self.input_file = opt.file
        
        self.output_file = opt.output if 'output' in opt else None
        self.output_file4 = opt.output4 if 'output4' in opt else None
        self.output_file6 = opt.output6 if 'output6' in opt else None

        self.input_stream = None
        self.output_stream = None
        self.output_stream4 = None
        self.output_stream6 = None

        self.rules_v4 = []
        self.rules_v6 = []
    
    @property
    def using_v4(self):
        """``True`` if the selected IP version includes IPv4."""
        return self.ip_ver in ['4', 'v4', 'ipv4', 'both']
    
    @property
    def using_v6(self):
        """``True`` if the selected IP version includes IPv6."""
        return self.ip_ver in ['6', 'v6', 'ipv6', 'both'] 

    @staticmethod
    def _get_stream(direction: str, dest: str, overwrite=False) -> TextIOWrapper:
        """
        Open ``dest`` for reading (``direction='in'``) or writing (``direction='out'``).

        :param direction: ``"in"`` or ``"out"``.
        :param dest:      A file path, or ``"-"`` for ``sys.stdin`` / ``sys.stdout`` respectively.
        :param overwrite: For output files: open with mode ``w`` when true, otherwise with the
                          exclusive-creation mode ``x``.
        :raises AttributeError: If ``direction`` is neither ``"in"`` nor ``"out"``.
        """
        modes = 'r'

        if direction == 'out':
            if dest == '-': 
                return sys.stdout
            
            modes = 'w' if overwrite else 'x'
        elif direction == 'in':
            if dest == '-': 
                return sys.stdin
        else:
            raise AttributeError('direction must be "in" or "out".')

        return open(dest, modes)
    
    def parse_stream(self, stream: TextIOWrapper = None) -> VER_TUPLE:
        """
        Read every line from ``stream``, strip it, and parse the lines with ``PyreParser.parse_lines``.

        :param stream: Stream to read. Falls back to ``self.input_stream`` and then to ``sys.stdin``.
        :return: Whatever ``PyreParser.parse_lines`` returns (unpacked by callers as ``ip4, ip6``).
        """
        stream = self.input_stream if stream is None else stream
        stream = sys.stdin if stream is None else stream

        lines = [line.strip() for line in stream.readlines()]
        return PyreParser().parse_lines(lines=lines)
    
    def parse_file(self, file=None) -> VER_TUPLE:
        """
        Locate ``file`` within ``SEARCH_DIRS`` and parse it with ``PyreParser.parse_file``.

        Prints an error and exits the process with status 1 when the file cannot be found.

        :param file: File name / path to look up. Defaults to ``self.input_file``.
        """
        f = self.input_file if file is None else file
        try:
            path = find_file(f, SEARCH_DIRS, extensions=conf.SEARCH_EXTENSIONS)
        except FileNotFoundError:
            err(f'ERROR: The file "{f}" could not be found in any of your search directories.')
            return sys.exit(1)
        err(f'Parsing file: {path}')
        p = PyreParser()
        return p.parse_file(path=path)

    @staticmethod
    def gen_start_line(filename: str, timestamp=None):
        """
        Build the ``### Generated by PyreWall ...`` header comment placed at the top of generated rules.

        :param filename:  Name of the source the rules were generated from.
        :param timestamp: ``datetime`` to embed (microseconds are dropped). Defaults to the current UTC time.
        """
        if not timestamp:
            # datetime.utcnow() is deprecated; take an aware "now" and drop the tzinfo so that the
            # rendered text stays free of a "+00:00" suffix.
            from datetime import timezone
            timestamp = datetime.now(timezone.utc).replace(tzinfo=None)
        timestamp = timestamp.replace(microsecond=0)
        return f'### Generated by PyreWall from file: "{filename}" at date/time: {timestamp.isoformat(" ")} UTC-0'

    @staticmethod
    def backup_rules(ipver='v4', backup_file=None, backup_dir='~/.pyrewall'):
        """
        Save the rules returned by ``save_rules(ipver)`` into ``backup_dir/backup_file``, one rule per line.

        :param ipver:       IP version passed to ``save_rules``.
        :param backup_file: File name inside ``backup_dir``. Defaults to ``old_rules.<ipver>``.
        :param backup_dir:  Directory for backups (``~`` is expanded); created if missing.
        :return: Path of the backup file that was written.
        """
        backup_dir = expanduser(backup_dir)
        backup_file = empty_if(backup_file, f"old_rules.{ipver}")
        
        if not exists(backup_dir):
            log.debug("Creating folder %s", backup_dir)
            # exist_ok avoids failing if the folder appears between the check and the creation.
            makedirs(backup_dir, exist_ok=True)
        
        bk_path = join(backup_dir, backup_file)
        rules = save_rules(ipver)
        
        with open(bk_path, 'w') as fh:
            fh.writelines([f"{l}\n" for l in rules])
        
        return bk_path

    @staticmethod
    def _read_rules_file(path: str) -> List[str]:
        """Read a rules file written by :meth:`.backup_rules` back into a list of rule lines."""
        with open(path) as fh:
            return [line.rstrip("\n") for line in fh]

    def load(self, file=None, confirm=True, timeout=15):
        """
        Parse rules from ``file`` (or stdin, or one of ``conf.MAIN_PYRE``) and load them with ``load_rules``.

        The current rules of each selected IP version are backed up first. With ``confirm`` enabled, the
        user is asked whether to keep the new rules; the backups are loaded again unless the answer is yes.

        :param file:    File to load, ``"-"`` for stdin. Defaults to ``self.input_file``. When empty, stdin is
                        used if it is not a TTY, otherwise ``conf.MAIN_PYRE`` is searched.
        :param confirm: Whether to prompt for confirmation after loading.
        :param timeout: Seconds to wait for the confirmation answer before rolling back.
        """
        f = self.input_file if file is None else file
        
        if f == '-' or (empty(f) and not sys.stdin.isatty()):
            self.input_stream = sys.stdin
            ip4, ip6 = self.parse_stream(stream=self.input_stream)
        elif empty(f):
            try:
                f = search_files(*conf.MAIN_PYRE)
                ip4, ip6 = self.parse_file(file=f)
            except FileNotFoundError:
                err("No filename was specified to load, and none of the MAIN_PYRE files could be found in the SEARCH_DIRS.\n")
                err(f"MAIN_PYRE: {conf.MAIN_PYRE}\n")
                err(f"SEARCH_DIRS: {conf.SEARCH_DIRS}\n")
                err()
                return sys.exit(1)
        else:
            ip4, ip6 = self.parse_file(file=f)
        
        v4_bk, v6_bk = 'N/A', 'N/A'
        if self.using_v4:
            log.info("Backing up old IPv4 rules...")
            v4_bk = self.backup_rules('v4')
            log.info("Backed up current IPv4 rules at %s", v4_bk)
            log.info("Loading IPv4 rules into iptables from file/stream %s", f)
            load_rules(rules=ip4, ipver='v4')

        if self.using_v6:
            log.info("Backing up old IPv6 rules...")
            v6_bk = self.backup_rules('v6')
            log.info("Backed up current IPv6 rules at %s", v6_bk)
            log.info("Loading IPv6 rules into iptables from file/stream %s", f)
            load_rules(rules=ip6, ipver='v6')
        
        log.info("Finished loading rules successfully :)")
        
        if not confirm:
            return

        def restore_rules():
            # load_rules() is given a list of rule lines everywhere else in this file, so the backup is read
            # back into that same shape instead of handing over the backup file's path.
            if self.using_v4:
                err(f"Restoring IPv4 rules from {v4_bk} ...")
                load_rules(rules=self._read_rules_file(v4_bk), ipver='v4')
            if self.using_v6:
                err(f"Restoring IPv6 rules from {v6_bk} ...")
                load_rules(rules=self._read_rules_file(v6_bk), ipver='v6')
        
        # NOTE(review): when the rules themselves were read from stdin, stdin has already been consumed up to
        # EOF, so the prompt below gets an empty answer straight away and the rules are rolled back.
        err("Just in-case something went wrong, you need to confirm whether you're still able to connect to this system or not.")
        err(f"If you don't answer within {timeout} seconds, we'll rollback to your old iptables rules.")
        timed, ans = timeout_input("Keep these rules? (y/N)", timeout=timeout)
        
        if timed == 0:
            if ans.lower() in ["y", "ye", "yes"]:
                return err("You said yes. Keeping your new rules :)\n")
            err("You didn't say yes, so we're going to assume something is wrong, and will rollback your rules.\n")
        else:
            err(f"No response after {timeout} seconds... automatically rolling back rules to be safe.\n")
        restore_rules()
        err("Finished rolling back rules.")
    
    def parse(self, file=None, output=None, overwrite=False):
        """
        Parse rules from ``file`` (or stdin) and write the generated IPv4 / IPv6 rules to the output stream(s).

        :param file:      File to parse, ``"-"`` for stdin. Defaults to ``self.input_file``. When empty, stdin
                          is used if it is not a TTY, otherwise ``parser.error`` is called.
        :param output:    Destination for all rules. When given, the per-version ``output4`` / ``output6``
                          destinations are ignored. Defaults to ``self.output_file``.
        :param overwrite: Passed to :meth:`._get_stream` - whether existing output files may be overwritten.
        """
        f = self.input_file if file is None else file
        custom_out = output is not None
        output = self.output_file if output is None else output

        # Resolve and parse the input BEFORE opening any output file, so that a missing or invalid
        # input does not leave an empty output file behind.
        if f == '-' or (empty(f) and not sys.stdin.isatty()):
            self.input_stream = sys.stdin
            ip4, ip6 = self.parse_stream(stream=self.input_stream)
        elif empty(f):
            return parser.error('Error! The following arguments are required: file')
        else:
            ip4, ip6 = self.parse_file(file=f)
        
        self.rules_v4, self.rules_v6 = ip4, ip6

        self.output_stream = self._get_stream(direction='out', dest=output, overwrite=overwrite)
        self.output_stream4, self.output_stream6 = self.output_stream, self.output_stream

        if not custom_out:
            if self.output_file4 is not None and self.output_file4 != self.output_file:
                self.output_stream4 = self._get_stream(direction='out', dest=self.output_file4, overwrite=overwrite)
            
            if self.output_file6 is not None and self.output_file6 != self.output_file:
                if self.output_file6 == self.output_file4:
                    # Same dedicated file for both versions: share one handle rather than opening it twice.
                    self.output_stream6 = self.output_stream4
                else:
                    self.output_stream6 = self._get_stream(direction='out', dest=self.output_file6, overwrite=overwrite)

        start_line = self.gen_start_line(filename=f)
        if self.using_v4:
            self._write_section(self.output_stream4, start_line, 'IPv4', ip4)
        
        # The separator belongs wherever both rule sets end up in the same stream.
        if self.output_stream4 is self.output_stream6:
            self.output_rule("\n#############################\n", dest=self.output_stream4)

        if self.using_v6:
            self._write_section(self.output_stream6, start_line, 'IPv6', ip6)

    def _write_section(self, dest, start_line: str, label: str, rules):
        """Write the header line, the ``label`` start marker, every rule, and the end marker to ``dest``."""
        self.output_rule(start_line, dest=dest)
        self.output_rule(f'# --- {label} Rules --- #', dest=dest)
        for line in rules:
            self.output_rule(line, dest=dest)
        self.output_rule(f'# --- End {label} Rules --- #', dest=dest)
    
    @staticmethod
    def output_rule(rule: str, dest: TextIOWrapper = sys.stdout):
        """Write ``rule`` plus a newline to ``dest``; the string ``"-"`` as ``dest`` means ``print`` it."""
        if dest == '-':
            return print(rule)
        
        dest.write(rule + "\n")
    
    # NOTE(review): no-op stub which takes no ``self``; the working implementation is the module-level
    # ``print_rules`` function. Left untouched to keep the class interface unchanged.
    def print_rules(ip4: list = None, ip6: list = None, ipver='both'):
        pass

    def _cleanup(self):
        """
        Close every stream held by this instance and reset the attributes to ``None``.

        The standard streams (``sys.stdin`` / ``sys.stdout`` / ``sys.stderr``) are never closed. Safe to call
        repeatedly, and on an instance whose ``__init__`` did not run to completion.
        """
        cls_name = self.__class__.__name__
        std_streams = (sys.stdin, sys.stdout, sys.stderr)

        for attr in self._STREAM_ATTRS:
            stream = getattr(self, attr, None)
            if stream is None:
                continue
            try:
                if not any(stream is std for std in std_streams):
                    stream.close()
            except Exception:
                log.exception("Error while closing %s.%s ...", cls_name, attr)
            setattr(self, attr, None)
    
    def __enter__(self):
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        self._cleanup()
    
    def __del__(self):
        # Last-resort cleanup for instances which were not used as a context manager.
        self._cleanup()


def print_rules(ip4: list = None, ip6: list = None, ipver='both'):
    """
    Print the given IPv4 and/or IPv6 rules to stdout, each set wrapped in start / end marker comments.

    A blank line is always printed between the two sections, even when one or both of them are skipped.

    :param ip4:   IPv4 rule lines (falsy values are treated as an empty list).
    :param ip6:   IPv6 rule lines (falsy values are treated as an empty list).
    :param ipver: ``4``/``v4``/``ipv4``, ``6``/``v6``/``ipv6`` or ``both`` (case-insensitive). A rule set
                  is only printed when its version is selected and it contains at least one rule.
    """
    def emit(label, rules):
        print(f'# --- {label} Rules --- #')
        for rule in rules:
            print(rule)
        print(f'# --- End {label} Rules --- #')

    v4_rules = ip4 if ip4 else []
    v6_rules = ip6 if ip6 else []
    selected = ipver.lower()

    if selected in ('4', 'v4', 'ipv4', 'both') and len(v4_rules) > 0:
        emit('IPv4', v4_rules)
    print()
    if selected in ('6', 'v6', 'ipv6', 'both') and len(v6_rules) > 0:
        emit('IPv6', v6_rules)


def ap_parse(opt):
    """
    Handler for the ``parse`` sub-command: parse the requested input and write the generated rules.

    ``RuleOutput`` is used as a context manager so that any output files it opens are closed (and
    therefore flushed) as soon as parsing finishes, instead of whenever the object is garbage collected.

    :param opt: Parsed CLI arguments for the ``parse`` sub-command.
    """
    with RuleOutput(opt) as k:
        k.parse()


def ap_reload(opt):
    """
    Handler for the ``load`` sub-command: parse the requested input and load the rules via ``RuleOutput.load``.

    ``RuleOutput`` is used as a context manager so that its streams are released deterministically
    once loading (and the optional confirmation prompt) has finished.

    :param opt: Parsed CLI arguments for the ``load`` sub-command (``file``, ``confirm``, ``confirm_timeout``).
    """
    with RuleOutput(opt) as k:
        k.load(file=opt.file, confirm=opt.confirm, timeout=int(opt.confirm_timeout))


def ap_repl(opt):
    """
    Handler for the ``repl`` sub-command: hand the file list from the CLI over to ``repl_main``.

    :param opt: Parsed CLI arguments for the ``repl`` sub-command (``files``).
    """
    preload_files = opt.files
    repl_main(files=preload_files)


def ap_install_service(opt):
    """
    Handler for the ``install_service`` sub-command.

    Copies ``SERVICE_FILE`` to ``SERVICE_FILE_DEST``, then runs ``systemctl daemon-reload``,
    ``systemctl enable pyrewall.service`` and ``systemctl start pyrewall.service``.

    Always terminates the process: exit status 0 on success, 1 when not running as root or when starting
    the service raises ``ReturnCodeError``.

    :param opt: Parsed CLI arguments (not used).
    """
    if not is_root():
        err(f"\nERROR: You must run '{sys.argv[0]} install_service' as root.")
        err(f"Try running: 'sudo {sys.argv[0]} install_service'\n")
        return sys.exit(1)

    err(f"\nCopying {SERVICE_FILE} into {SERVICE_FILE_DEST}")
    copyfile(SERVICE_FILE, SERVICE_FILE_DEST)

    # These two steps are not guarded: a failure here propagates to the caller.
    for notice, systemctl_args in (
        ("\nReloading systemd with daemon-reload", ('daemon-reload',)),
        ("\nEnabling pyrewall service", ('enable', 'pyrewall.service')),
    ):
        err(notice)
        run_prog_ex('systemctl', *systemctl_args)

    try:
        err("Starting pyrewall service")
        run_prog_ex('systemctl', 'start', 'pyrewall.service')
    except ReturnCodeError:
        err("Something went wrong starting Pyrewall.")
        err("If you don't yet have a master Pyrewall rules file, e.g. /etc/pyrewall/rules.pyre - then it's most likely")
        err(f"just '{sys.argv[0]} load -n' failing to find a valid master rules file.")
        err("You can run 'journalctl -u pyrewall' to see the logs from the service.")
        return sys.exit(1)
    else:
        err("Successfully installed the Pyrewall service. Your master rules file will now be auto-loaded on boot.\n")
        return sys.exit(0)


# Sub-command definitions. Each sub-parser stores its handler function as the ``func`` default, which
# the dispatch code at the bottom of the file looks up on the parsed namespace.
sp = parser.add_subparsers()

# --- parse: convert a Pyrewall file (or stdin) into iptables-restore compatible rules ---
parse_sp = sp.add_parser('parse', description=CMD_DESC['parse'])
parse_sp.add_argument('file', default=None, help='Pyrewall file to parse', nargs='?')
parse_sp.add_argument(
    '-i', type=str, default='both', dest='ipver',
    help='4 = Output only IPv4 config, 6 = Output only IPv6 config, both = Output both configurations (default)'
)

parse_sp.add_argument(
    '--output', '-o', type=str, default='-', dest='output',
    help='Output the IPTables rules lines to this file (default "-" (stdout))'
)

parse_sp.add_argument(
    '--output6', '-o6', type=str, default=None, dest='output6',
    help='Output only the IPv6 IPTables rules lines to this file (defaults to value of shared "--output")'
)

parse_sp.add_argument(
    '--output4', '-o4', type=str, default=None, dest='output4',
    help='Output only the IPv4 IPTables rules lines to this file (defaults to value of shared "--output")'
)


parse_sp.set_defaults(func=ap_parse)

# --- load: parse a Pyrewall file and load the resulting rules ---
reload_sp = sp.add_parser('load', description=CMD_DESC['load'])
reload_sp.add_argument(
    '-i', type=str, default='both', dest='ipver',
    help='4 = Load only IPv4 rules, 6 = Load only IPv6 rules, both = Load both rule sets (default)'
)
reload_sp.add_argument(
    '-t', '--timeout', type=int, default=15, dest='confirm_timeout',
    help='(default: 15) Amount of seconds to wait for user to confirm rules are working safely, before automatically rolling back'
)
reload_sp.add_argument(
    '-n', '--noninteractive', '--no-confirm', dest='confirm', action='store_false', default=True,
    help='Do not prompt to confirm whether or not to keep the newly loaded rules',
)
reload_sp.add_argument('file', help='Pyrewall file to (re-)load into IPTables', default=None, nargs='?')

reload_sp.set_defaults(func=ap_reload, confirm=True)

# --- repl: the REPL gets its own description instead of re-using the text of the 'parse' sub-command ---
parse_repl = sp.add_parser(
    'repl', description='Open the Pyrewall REPL, optionally reading Pyrewall file(s) into it first'
)
parse_repl.add_argument('files', help='Optionally read these Pyrewall file(s) into the REPL in order', nargs='*')
parse_repl.set_defaults(func=ap_repl)

# --- install_service: install + enable + start the systemd service (see ap_install_service) ---
install_service_sp = sp.add_parser('install_service', description=CMD_DESC['install_service'])
install_service_sp.set_defaults(func=ap_install_service)

args = parser.parse_args()

# When no sub-command is given, the parsed namespace has no ``func`` attribute
# (see https://stackoverflow.com/a/54161510/2648583).
# Only the *lookup* is guarded: wrapping the handler call itself in ``except AttributeError`` would
# misreport any AttributeError raised inside a handler as "Too few arguments".
func = getattr(args, 'func', None)

if func is None:
    # No sub-command: if something is being piped in, parse it from stdin and print the rules.
    if not sys.stdin.isatty():
        parse_stdin()
        sys.exit(0)
    parser.error('Too few arguments')
    sys.exit(1)

func(args)

