#!python
# coding: utf-8

import sys
import argparse

from twisted.internet import defer
from twisted.internet import protocol
from twisted.internet import reactor
from twisted.python import log

sys.path.append(".")

import marionette
import marionette.conf
import marionette.dsl

# Text printed by --version: a banner followed by one indented line per
# entry returned by marionette.dsl.list_mar_files('server').
mar_files = marionette.dsl.list_mar_files('server')
ver_string = "Marionette proxy server.\nAvailable formats:\n" + "".join(
    " %s\n" % (name) for name in mar_files)

# Command-line interface. Every option is optional; options that are given
# override the corresponding marionette.conf entries further down.
# RawTextHelpFormatter keeps the newlines of the multi-line --version text.
parser = argparse.ArgumentParser(description='Marionette proxy server.',
    formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('--version', action='version', version=ver_string)
# NOTE(review): this help text mentions the "client" although this is the
# server script -- it looks copied from the client; confirm intended wording.
parser.add_argument('--server_ip', '-sip', dest='server_ip', required=False,
    help='IP address for client to bind to')
# type=int lets argparse reject a non-numeric port with a usage error
# instead of an uncaught ValueError later on. The port is the one this
# server connects to (it is passed to reactor.connectTCP), not one it binds.
parser.add_argument('--proxy_port', '-pport', dest='proxy_port', type=int,
    required=False, help='port to connect to')
parser.add_argument('--proxy_ip', '-pip', dest='proxy_ip', required=False,
    help='server IP address to connect to')
parser.add_argument('--format', '-f', dest='format', required=False,
    help='Marionette format to use for connection')
parser.add_argument('--debug', '-d', dest='debug', required=False, action='store_true',
    help='Turn on debug output')
args = parser.parse_args()

# Options supplied on the command line override the stored configuration.
# Each entry is (configuration key, parsed value, conversion to apply);
# the entries are applied in the order listed.
_overrides = (
    ('server.server_ip', args.server_ip, str),
    ('server.proxy_ip', args.proxy_ip, str),
    ('server.proxy_port', args.proxy_port, int),
    ('general.format', args.format, str),
)
for _key, _value, _convert in _overrides:
    if _value is not None:
        marionette.conf.set(_key, _convert(_value))

# --debug is a store_true flag, so it only overrides when actually given.
if args.debug:
    marionette.conf.set('general.debug', args.debug)

# Effective settings, read back from the configuration so that values set
# on the command line and values already present in the config are treated
# the same way.
LOCAL_IP = marionette.conf.get('server.server_ip')
REMOTE_IP = marionette.conf.get('server.proxy_ip')
REMOTE_PORT = marionette.conf.get('server.proxy_port')

# The format may be written as "name:version"; separate the two parts.
FORMAT = marionette.conf.get('general.format')
FORMAT_VERSION = None
if ':' in FORMAT:
    # Split the configured value itself, not args.format: --format is
    # optional, so args.format is None whenever the format comes from the
    # configuration, and calling .split on it would raise AttributeError.
    FORMAT, FORMAT_VERSION = FORMAT.split(':', 1)


class ProxyServerProtocol(protocol.Protocol):
    """Relay between the factory's queues and the connected TCP peer.

    Items taken from the factory's ``cli_queue`` are written to the peer,
    and data received from the peer is put on the factory's ``srv_queue``.
    The value ``False`` on ``cli_queue`` is the signal to disconnect.
    """

    def connectionMade(self):
        """Adopt the factory's cli_queue and wait for its first item."""
        log.msg("ProxyServerProtocol: connected to peer")
        self.cli_queue = self.factory.cli_queue
        pending = self.cli_queue.get()
        pending.addCallback(self.serverDataReceived)

    def serverDataReceived(self, chunk):
        """Process one item taken from cli_queue.

        ``False`` closes the connection. Any other item is written to the
        peer while this protocol still holds its queue reference; once
        that reference has been cleared the item is put back on the
        factory's queue instead.
        """
        if chunk is False:
            # Disconnect signal: drop the queue reference first so that
            # connectionLost does not report the close as unexpected.
            self.cli_queue = None
            log.msg("ProxyServerProtocol: disconnecting from peer")
            self.factory.continueTrying = False
            self.transport.loseConnection()
            return

        if not self.cli_queue:
            # Queue reference already cleared: return the chunk to the
            # factory's queue rather than writing to the transport.
            log.msg(
                "ProxyServerProtocol: (2) writing %d bytes to peer" %
                len(chunk))
            self.factory.cli_queue.put(chunk)
            return

        log.msg(
            "ProxyServerProtocol: writing %d bytes to peer" %
            len(chunk))

        # str chunks are encoded as latin-1, which maps code points 0-255
        # one-to-one onto byte values; other chunk types pass through as-is.
        payload = chunk.encode('latin-1') if isinstance(chunk, str) else chunk
        self.transport.write(payload)

        # Re-arm: wait for the next queued item.
        self.cli_queue.get().addCallback(self.serverDataReceived)

    def dataReceived(self, chunk):
        """Put data received from the peer on the factory's srv_queue."""
        log.msg(
            "ProxyServerProtocol: %d bytes received from peer" %
            len(chunk))
        self.factory.srv_queue.put(chunk)

    def connectionLost(self, why):
        """Log the loss; clear the queue reference if it was still held."""
        log.msg("ProxyServerProtocol.connectionLost: " + str(why))
        if not self.cli_queue:
            # Already cleared by the disconnect signal; nothing more to do.
            return
        self.cli_queue = None
        log.msg("ProxyServerProtocol: peer disconnected unexpectedly")


class ProxyServerFactory(protocol.ClientFactory):
    """Client factory that exposes the two relay queues to its protocol.

    ProxyServerProtocol reads ``cli_queue`` and ``srv_queue`` through its
    ``factory`` attribute.
    """

    protocol = ProxyServerProtocol

    def __init__(self, srv_queue, cli_queue):
        """Store the queues; note the (srv_queue, cli_queue) argument order."""
        self.srv_queue, self.cli_queue = srv_queue, cli_queue


class ProxyServer(object):
    """Bridge one Marionette stream to a TCP connection to REMOTE_IP:REMOTE_PORT.

    Data handed to ``dataReceived`` is queued on ``cli_queue`` for the TCP
    side; data arriving on ``srv_queue`` is pushed to the Marionette stream.
    """

    def __init__(self):
        # No outgoing connection exists until connectionMade runs.
        self.connector = None

    def connectionMade(self, marionette_stream):
        """Create the relay queues and open the outgoing TCP connection."""
        log.msg("ProxyServer.connectionMade")
        self.marionette_stream = marionette_stream

        to_peer = defer.DeferredQueue()
        from_peer = defer.DeferredQueue()
        self.cli_queue = to_peer
        self.srv_queue = from_peer

        # Start waiting for the first chunk coming back over TCP.
        from_peer.get().addCallback(self.clientDataReceived)

        self.factory = ProxyServerFactory(from_peer, to_peer)
        self.connector = reactor.connectTCP(REMOTE_IP, REMOTE_PORT,
                                            self.factory)

    def clientDataReceived(self, chunk):
        """Push one chunk taken from srv_queue into the Marionette stream."""
        log.msg(
            "ProxyServer.clientDataReceived: pushing %d bytes to marionette stream" %
            len(chunk))
        # The chunk is forwarded unchanged, exactly as it came off srv_queue.
        self.marionette_stream.push(chunk)
        # Re-arm: wait for the next chunk.
        next_chunk = self.srv_queue.get()
        next_chunk.addCallback(self.clientDataReceived)

    def dataReceived(self, chunk):
        """Queue a chunk on cli_queue for the TCP side to write out."""
        log.msg("ProxyServer.dataReceived: %s bytes" % len(chunk))
        self.cli_queue.put(chunk)

    def connectionLost(self):
        """Signal the TCP side to stop, then disconnect the connector."""
        log.msg("ProxyServer.connectionLost")
        # False is the value ProxyServerProtocol treats as "disconnect".
        self.cli_queue.put(False)
        self.connector.disconnect()


if __name__ == "__main__":
    # Twisted log output goes to stdout only when debug mode is configured.
    debug_enabled = marionette.conf.get("general.debug")
    if debug_enabled:
        log.startLogging(sys.stdout)

    # Build the Marionette server for the chosen format and hand it the
    # ProxyServer class through its `factory` attribute.
    server = marionette.Server(FORMAT)
    server.factory = ProxyServer

    # Queue server.execute(reactor) with the reactor, then start the
    # event loop.
    reactor.callFromThread(server.execute, reactor)
    reactor.run()
