#!python
# coding: utf-8

import sys
import argparse

from twisted.internet import defer
from twisted.internet import protocol
from twisted.internet import reactor
from twisted.python import log

sys.path.append(".")

import marionette
import marionette.conf
import marionette.dsl

# Text printed by --version: a banner followed by one indented line per
# client-side format reported by marionette.dsl.
mar_files = marionette.dsl.list_mar_files('client')
ver_string = "".join(
    ["Marionette proxy client.\nAvailable formats:\n"]
    + [" %s\n" % (mar_file) for mar_file in mar_files]
)

# Command-line interface. RawTextHelpFormatter is used so the newlines
# embedded in the --version text are kept as written.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description='Marionette proxy client.',
)
parser.add_argument(
    '--version',
    action='version',
    version=ver_string,
)

# Every option below is optional; an omitted value option is left as None
# on the parsed namespace.
parser.add_argument(
    '--client_ip', '-cip',
    dest='client_ip',
    required=False,
    help='IP address for client to bind to',
)
parser.add_argument(
    '--client_port', '-cport',
    dest='client_port',
    required=False,
    help='port for client to bind to',
)
parser.add_argument(
    '--server_ip', '-sip',
    dest='server_ip',
    required=False,
    help='server IP address to connect to',
)
parser.add_argument(
    '--format', '-f',
    dest='format',
    required=False,
    help='Marionette format to use for connection',
)
parser.add_argument(
    '--debug', '-d',
    dest='debug',
    required=False,
    action='store_true',
    help='Turn on debug output',
)

args = parser.parse_args()

# Command-line options, when supplied, are written into marionette.conf;
# options that were not given leave the corresponding setting untouched.
# None is compared by identity, as recommended for singletons.
if args.server_ip is not None:
    marionette.conf.set('server.server_ip', str(args.server_ip))
if args.client_ip is not None:
    marionette.conf.set('client.client_ip', str(args.client_ip))
if args.client_port is not None:
    # The port arrives from argparse as a string; it is stored as an int.
    marionette.conf.set('client.client_port', int(args.client_port))
if args.format is not None:
    marionette.conf.set('general.format', str(args.format))
# --debug is a store_true flag, so it is only ever True or False.
if args.debug:
    marionette.conf.set('general.debug', args.debug)

# Local endpoint the proxy listens on (passed to reactor.listenTCP below).
LOCAL_IP = marionette.conf.get('client.client_ip')
LOCAL_PORT = marionette.conf.get('client.client_port')

# The configured format may carry an optional version suffix, "name:version".
# Split the configured value itself rather than args.format: --format is
# optional, so args.format is None whenever the format comes from the
# configuration alone, and calling split() on None raises AttributeError.
FORMAT = marionette.conf.get('general.format')
FORMAT_VERSION = None
if ':' in FORMAT:
    FORMAT, FORMAT_VERSION = FORMAT.split(':', 1)



class ProxyClient(protocol.Protocol):
    """Bridge one accepted local connection to a Marionette client stream.

    Bytes received on the connection are pushed into the stream obtained
    from the module-level ``client``; chunks taken from ``srv_queue`` (the
    queue handed to that stream) are written back to the connection.
    """

    def _await_next_chunk(self):
        # Each Deferred returned by the queue delivers a single item, so a
        # fresh callback is registered for every chunk we want to receive.
        pending = self.srv_queue.get()
        pending.addCallback(self.clientDataReceived)

    def connectionMade(self):
        """Create the queue, start listening on it, and open a new stream."""
        log.msg("ProxyClient.connectionMade")
        queue = defer.DeferredQueue()
        self.srv_queue = queue
        self._await_next_chunk()
        self.client_stream_ = client.start_new_stream(queue)

    def clientDataReceived(self, chunk):
        """Write a chunk taken from ``srv_queue`` to the local connection."""
        size = len(chunk)
        log.msg("ProxyClient: writing %d bytes to original client" % size)

        # Text chunks are turned into bytes with latin-1, which maps code
        # points 0-255 one-to-one onto byte values; bytes pass through as-is.
        payload = chunk.encode('latin-1') if isinstance(chunk, str) else chunk
        self.transport.write(payload)
        self._await_next_chunk()

    def dataReceived(self, chunk):
        """Push bytes received on the local connection into the stream."""
        received = len(chunk)
        log.msg("ProxyClient: %d bytes received" % received)
        self.client_stream_.push(chunk)

    def connectionLost(self, why):
        """Log the reason for the disconnect and terminate the stream."""
        reason_text = str(why)
        log.msg("ProxyClient.connectionLost: " + reason_text)
        self.client_stream_.terminate()


if __name__ == "__main__":
    # Twisted log output goes to stdout only when debugging is enabled.
    debug_enabled = marionette.conf.get("general.debug")
    if debug_enabled:
        log.startLogging(sys.stdout)

    # Kept as a module-level name on purpose: ProxyClient.connectionMade
    # looks up ``client`` to open a stream for each accepted connection.
    client = marionette.Client(FORMAT, FORMAT_VERSION)

    # Accept local connections and hand each one to a ProxyClient instance.
    listener_factory = protocol.Factory()
    listener_factory.protocol = ProxyClient
    reactor.listenTCP(LOCAL_PORT, listener_factory, interface=LOCAL_IP)

    # Queue client.execute(reactor) for the reactor, then start the loop.
    reactor.callFromThread(client.execute, reactor)
    reactor.run()
