#!/bin/python3
# -*- coding: utf-8 -*-
#
# This file is part of INGInious. See the LICENSE and the COPYRIGHTS files for
# more information about the licensing of this file.

import json
import logging
import sys
import shutil
import os
import os.path
import inginious_container_api.feedback
from inginious_container_api.run_types import run_types
from inginious_container_api.utils import start_ssh_server, ssh_wait, execute_process
import tarfile
import base64
import msgpack
import asyncio
import struct
import zmq
import zmq.asyncio


# First four bytes of an ELF binary: 0x7F followed by the ASCII letters "ELF".
# NOTE(review): neither constant is referenced in the part of the file visible here.
ELF_MAGIC = b"\x7F\x45\x4C\x46"
# First two bytes of a script starting with a shebang line: the ASCII characters "#!".
SHEBANG_MAGIC = b"\x23\x21"


class INGIniousMainRunner:
    def __init__(self, ctx, loop):
        """
        Store the messaging context and the event loop, and reset the runner state.

        :param ctx: context from which serve() creates the internal sockets
        :param loop: event loop on which the runner schedules its tasks
        """
        self._ctx = ctx
        self._loop = loop
        self._logger = logging.getLogger("inginious.container")
        self._logger.info("Hello")
        # The attributes below are populated by serve()
        self.stdout = None  # stream writer connected to the process' stdout
        self.intern = None  # ROUTER socket bound to the internal ipc endpoint
        self.internal_socket_send = None  # REQ socket connected to that same endpoint
        self.running_student_container = None  # dict socket_id -> address of the requester
        # The attributes below are overwritten by _start_cmd_sync() from the "start" message
        self.run_as_root = False
        # Declared here so that the attribute exists on every instance, even before a
        # "start" message has been processed.
        self.shared_kernel = False

    def copy_tree(self, src, dst, symlinks=False, ignore=None):
        """
        Copy the content of ``src`` into the existing directory ``dst``, merging
        with whatever ``dst`` already contains.

        Files are copied with shutil.copy2 and overwrite files of the same name.
        A sub-directory that does not exist yet in ``dst`` is copied with
        shutil.copytree; one that already exists is merged recursively.

        :param src: directory whose content is copied
        :param dst: existing directory receiving the content
        :param symlinks: forwarded to shutil.copytree
        :param ignore: forwarded to shutil.copytree. It therefore only filters the
                       content of sub-directories created by this call; entries located
                       directly in ``src`` or in a merged directory are always copied.
        """
        for item in os.listdir(src):
            s = os.path.join(src, item)
            d = os.path.join(dst, item)
            if not os.path.isdir(s):
                shutil.copy2(s, d)
            elif os.path.isdir(d):
                # shutil.copytree refuses to write into an existing directory,
                # so merge the two directories entry by entry instead.
                self.copy_tree(s, d, symlinks, ignore)
            else:
                shutil.copytree(s, d, symlinks, ignore)

    def set_directory_rights(self, path):
        """
        Give mode 0o777 and owner/group 4242 to ``path`` and to every directory and
        file found below it by os.walk.

        :param path: root of the tree to update
        """
        def grant(target):
            # 4242:4242 is presumably the unprivileged "worker" account - TODO confirm
            os.chmod(target, 0o777)
            os.chown(target, 4242, 4242)

        grant(path)
        for parent, subdirs, filenames in os.walk(path):
            # directories first, then files, for each visited directory
            for entry in subdirs + filenames:
                grant(os.path.join(parent, entry))

    def b64tararchive(self):
        """
        Pack the content of /archive/ into the gzipped tarball /tmp/archive.tgz and
        return that tarball encoded in base64.

        :return: the base64 text (str) of the tarball
        """
        tgz_path = '/tmp/archive.tgz'

        with tarfile.open(tgz_path, "w:gz") as archive:
            # members are stored relative to the archive root
            archive.add('/archive/', arcname='/')

        with open(tgz_path, "rb") as packed:
            raw = packed.read()

        return base64.b64encode(raw).decode('utf-8')

    async def stdio(self):
        """
        Wrap the standard input and output of the process into asyncio streams.

        :return: (reader, writer) connected to stdin/stdout
        """
        event_loop = asyncio.get_event_loop()

        stdin_reader = asyncio.StreamReader()
        stdin_protocol = asyncio.StreamReaderProtocol(stdin_reader)

        # File descriptor 1 (stdout) is reopened in binary mode for the writer side
        stdout_pipe = os.fdopen(1, 'wb')
        transport, protocol = await event_loop.connect_write_pipe(asyncio.streams.FlowControlMixin, stdout_pipe)
        stdout_writer = asyncio.StreamWriter(transport, protocol, None, event_loop)

        await event_loop.connect_read_pipe(lambda: stdin_protocol, sys.stdin)

        return stdin_reader, stdout_writer

    async def handle_stdin(self, reader: asyncio.StreamReader):
        """
        Handle messages from the agent.

        Each message is a 4-byte unsigned length in network byte order followed by
        that many bytes of msgpack data. Every complete message is decoded and passed
        to handle_stdin_message(); the coroutine returns when the stream ends. Data
        truncated by the end of the stream is discarded.

        Any unexpected error is logged and terminates the whole process.

        :param reader: stream from which the messages are read
        """
        try:
            while True:
                try:
                    header = await reader.readexactly(4)
                    length = struct.unpack('!I', header)[0]
                    payload = await reader.readexactly(length)
                except asyncio.IncompleteReadError:
                    # End of stream (possibly in the middle of a message): stop reading.
                    # readexactly() only fails when bytes are really missing, so a
                    # message that was fully received just before EOF is still handled.
                    return
                message = msgpack.unpackb(payload, use_list=False)
                await self.handle_stdin_message(message)
        except (asyncio.CancelledError, KeyboardInterrupt):
            return
        except Exception:
            self._logger.exception("Exception occured while reading stdin")
            os._exit(1)  # DIE!

    async def send_intern_message(self, msg):
        """
        Send ``msg`` (msgpack-encoded) through the internal REQ socket and wait for
        the answer.

        :param msg: message to send
        """
        sender = self.internal_socket_send
        encoded = msgpack.dumps(msg, use_bin_type=True)
        await sender.send(encoded)
        # The answer is awaited but its content is not used
        await sender.recv()

    async def serve(self):
        """
        Main loop of the runner.

        Starts the task reading stdin, creates the internal sockets, then dispatches
        every message received on the ROUTER socket to handle_intern_message() until
        that handler asks to stop.
        """
        self._logger.info("Starting serve")
        stdin_reader, self.stdout = await self.stdio()
        self._loop.create_task(self.handle_stdin(stdin_reader))

        endpoint = "ipc:///sockets/main.sock"

        self.intern = self._ctx.socket(zmq.ROUTER)
        self.intern.bind(endpoint)

        # The runner talks to its own ROUTER socket through this REQ socket
        self.internal_socket_send = self._ctx.socket(zmq.REQ)
        self.internal_socket_send.connect(endpoint)

        self.running_student_container = {}  # socket_id : addr

        poller = zmq.asyncio.Poller()
        poller.register(self.intern, zmq.POLLIN)

        self._logger.info("Serving...")
        try:
            should_stop = False
            while not should_stop:
                ready = dict(await poller.poll())
                if self.intern not in ready:
                    continue

                # New message from process in the container
                addr, _, payload = await self.intern.recv_multipart()
                msg = msgpack.loads(payload, use_list=False, strict_map_key=False)
                should_stop = await self.handle_intern_message(addr, msg)
        except (asyncio.CancelledError, KeyboardInterrupt):
            return
        except:
            self._logger.exception("An exception occured while serving")

    async def write_stdout(self, o):
        """
        Send ``o`` on stdout: its msgpack encoding preceded by the length of that
        encoding, packed as a 4-byte unsigned integer in network byte order.

        :param o: object to send
        """
        payload = msgpack.dumps(o, use_bin_type=True)
        header = struct.pack('!I', len(payload))
        for chunk in (header, payload):
            self.stdout.write(chunk)
        await self.stdout.drain()

    async def handle_stdin_message(self, message):
        """
        Process one message read from stdin.

        - "start" schedules start_cmd() as a task;
        - "run_student_started", "run_student_retval", "stdout" and "stderr" are relayed
          on the internal ROUTER socket to the address registered for their "socket_id"
          (and silently dropped when that id is unknown). "run_student_retval" also
          removes the registration;
        - any other type is ignored.

        Errors raised while processing are logged and swallowed.

        :param message: decoded message, a mapping with at least a "type" key
        """
        kind = message["type"]
        self._logger.info("received message %s", kind)
        try:
            if kind == "start":
                self._loop.create_task(self.start_cmd(message))
                return
            if kind not in ("run_student_started", "run_student_retval", "stdout", "stderr"):
                return

            socket_id = message["socket_id"]
            if socket_id not in self.running_student_container:
                return
            addr = self.running_student_container[socket_id]

            if kind == "run_student_started":
                answer = {"type": "run_student_started", "container_id": message["container_id"]}
            elif kind == "run_student_retval":
                # the student container is over: forget it before relaying the value
                del self.running_student_container[socket_id]
                answer = {"type": "run_student_retval", "retval": message["retval"]}
            else:
                # stdout/stderr are relayed untouched
                answer = message
            await self.intern.send_multipart([addr, b'', msgpack.dumps(answer, use_bin_type=True)])

        except:
            self._logger.exception("An exception occured while reading stdin")

    async def handle_intern_message(self, addr, message):
        """
        Process one message received on the internal ROUTER socket.

        Most messages are forwarded on stdout; depending on the type an answer is
        sent back to ``addr``. Errors are logged and swallowed.

        :param addr: address of the sender, used to route the answer
        :param message: decoded message, a mapping with at least a "type" key
        :return: True when the message is "done" (the serving loop must stop), False for
                 every other message, None when an error occurred while handling it
        """
        self._logger.info("received intern message %s", message)

        def extract(kind, *fields):
            # copy the dict manually to ensure the correctness of the message
            forwarded = {"type": kind}
            for field in fields:
                forwarded[field] = message[field]
            return forwarded

        async def answer(kind):
            await self.intern.send_multipart([addr, b'', msgpack.dumps({"type": kind}, use_bin_type=True)])

        try:
            kind = message["type"]
            if kind == "ssh_debug":
                await self.write_stdout(extract("ssh_debug", "ssh_user", "ssh_key"))
                await answer("ok")
            elif kind == "ssh_student":
                await self.write_stdout(extract("ssh_student", "ssh_user", "ssh_key", "container_id"))
                await answer("ok")
            elif kind == "run_student":
                # No answer is sent here: the sender is remembered so that messages
                # coming from stdin for this socket_id can be routed back to it.
                self.running_student_container[message["socket_id"]] = addr
                await self.write_stdout(extract("run_student", "environment", "time_limit", "hard_time_limit",
                                                "memory_limit", "share_network", "socket_id", "ssh",
                                                "run_as_root"))
            elif kind == "run_student_init":
                # This message may be sent from run_student to transfer to the student_container
                # via the agent (only when starting a kata student_container)
                await self.write_stdout(extract("run_student_init", "socket_id", "student_container_id",
                                                "command", "teardown_script", "working_dir", "ssh", "user"))
                await answer("dummy_message")  # ping pong answer
            elif kind in ("student_signal", "stdin"):
                await self.write_stdout(message)
                await answer("dummy_message")  # ping pong answer
            elif kind == "done":
                await answer("ok")
                return True
            # "dummy_message" and unknown types are ignored
            return False
        except:
            self._logger.exception("Exception occured while handling an internal message")

    async def start_cmd(self, data):
        """
        Run _start_cmd_sync() in the default executor of the loop and send its result
        on stdout; if anything fails, a "crash" result is sent instead. In both cases
        a "done" message is then sent on the internal socket.

        :param data: content of the "start" message
        """
        try:
            outcome = await self._loop.run_in_executor(None, self._start_cmd_sync, data)
            await self.write_stdout({"type": "result", "result": outcome})
        except Exception as exc:
            self._logger.exception("Exception while running start_cmd")
            crash_report = {
                "result": "crash",
                "text": "Exception occured in container ({})".format(exc),
                "problems": {}
            }
            await self.write_stdout({"type": "result", "result": crash_report})

        await self.send_intern_message({"type": "done"})

    def write_inputdata(self, data, input_dir="/.__input"):
        """
        Dump ``data`` as JSON into ``<input_dir>/__inputdata.json``.

        Entries of ``data["input"]`` that are a dict with exactly the keys "filename"
        and "value" have their value written to ``<input_dir>/__inputfile_<n>`` (``n``
        being a counter starting at 0); inside ``data`` itself, "value" is then
        replaced by the path of that file. An entry whose value cannot be written is
        left untouched and a warning is logged.

        :param data: data to dump; modified in place as described above
        :param input_dir: existing directory receiving the files
        """
        file_idx = 0
        if "input" in data:
            for entry in data["input"].values():
                if not isinstance(entry, dict) or entry.keys() != {"filename", "value"}:
                    continue
                target = os.path.join(input_dir, "__inputfile_{}".format(file_idx))
                file_idx += 1
                try:
                    with open(target, 'wb') as stored:
                        stored.write(entry["value"])
                except (TypeError, OSError):
                    # Best effort: the value is kept as it is in the JSON dump
                    self._logger.warning("Unable to write input file %s", target, exc_info=True)
                    continue
                entry["value"] = target

        with open(os.path.join(input_dir, '__inputdata.json'), 'w') as output:
            output.write(json.dumps(data))

    def _detect_run_file(self, run_cmd=None):
        """ Detect the run file. The order of the resolution is as follows:
        - If run_cmd is given and is not blank, it is split on whitespace and run.
        - For each directory `dir` among /task, /course/common and /course, in that order:
            - if `dir/run` exists, then run it
            - for all runtypes (with extension `ext` and command `cmd`) inside the run_types dictionary:
              if `dir/run.ext` exists, then run `cmd dir/run.ext`
        - Otherwise no run file is available.

        :param run_cmd: optional command line overriding the detection
        :returns: a list containing the command to be run by execute_process, or None if no run file is found.
        """

        if run_cmd is not None:
            # split() without argument never yields empty arguments, even with
            # repeated, leading or trailing spaces
            argv = run_cmd.split()
            if argv:
                return argv
            # a blank command is treated like no command at all
        for prefix in ["task", "course/common", "course"]:
            if os.path.exists("/{}/run".format(prefix)):
                return ["/{}/run".format(prefix)]
            for ext, cmd in run_types.items():
                if os.path.exists("/{}/run.{}".format(prefix, ext)):
                    return cmd + ["/{}/run.{}".format(prefix, ext)]

        return None

    def _start_cmd_sync(self, data):
        """
        Prepare the filesystem of the container, run the grading command (or an SSH
        debug session) and build the result.

        This method is blocking; start_cmd() runs it in an executor.

        :param data: content of the "start" message. The keys read here are "debug",
                     "run_as_root", "shared_kernel" and "run_cmd"; the whole dict is also
                     given to write_inputdata().
        :return: the feedback dict completed with the base64 archive of /archive or, when
                 no feedback is available, a "crash" result. In debug mode the captured
                 stdout/stderr are added to the returned dict.
        """
        self._logger.info("starting run")
        # Determining if debug mode or not: the --debug command line flag wins, otherwise the
        # "debug" value of the message is used (it is compared to "ssh" further down)
        debug = (sys.argv[1:] and sys.argv[1] == '--debug') or data.get("debug", False)

        # Run as root? Shared kernel?
        self.run_as_root = data.get('run_as_root', False)
        self.shared_kernel = data.get('shared_kernel', False)

        # Create input data directory
        if not os.path.exists("/.__input"):
            os.mkdir("/.__input")
        self.write_inputdata(data)

        # Create output directory
        if not os.path.exists("/.__output"):
            os.mkdir("/.__output")

        # Create archive directory
        if not os.path.exists("/archive"):
            os.mkdir("/archive")

        # Touch __feedback.json (to set the rights)
        open('/.__output/__feedback.json', 'w').close()

        # Verify that task directory exists
        if not os.path.exists("/task"):
            os.mkdir("/task")

        # Make sure that the directory /task/.ssh does not exist
        if os.path.exists("/task/.ssh"):
            shutil.rmtree("/task/.ssh")

        # Set rights on some files (mode 0o777, owner and group 4242, recursively)
        self.set_directory_rights("/tmp")
        self.set_directory_rights("/task")
        self.set_directory_rights("/.__output")
        self.set_directory_rights("/archive")

        # Set a flag file indicating if the runtime allows run_student to be run
        if self.shared_kernel:
            open('/.__input/__shared_kernel', 'w').close()

        ok_to_start = True

        # Add some elements to /etc/hosts and /etc/resolv.conf if needed
        # name in /task/systemfiles/ -> (system file, append?): hosts is appended to,
        # resolv.conf is overwritten
        system_files = {"hosts": ("/etc/hosts", True), "resolv.conf": ("/etc/resolv.conf", False)}
        for name, (spath, append) in system_files.items():
            if os.path.exists(os.path.join('/task/systemfiles/', name)):
                try:
                    open(spath, 'ab' if append else 'wb').write(
                        b'\n' + open(os.path.join('/task/systemfiles/', name), 'rb').read())
                except IOError:
                    # The task is not run at all if its system files cannot be installed
                    inginious_container_api.feedback.set_global_result('crash')
                    inginious_container_api.feedback.set_global_feedback(
                        "Cannot read/use %s" % os.path.join('/task/systemfiles/', name))
                    ok_to_start = False
        # Both files are given to root with group 4242 and are made group-writable
        os.chown('/etc/hosts', 0, 4242)
        os.chown('/etc/resolv.conf', 0, 4242)
        os.chmod('/etc/hosts', 0o664)
        os.chmod('/etc/resolv.conf', 0o664)

        # Ensure the worker can read the main socket
        try:
            os.chown('/sockets/main.sock', 4242, 4242)
            os.chmod('/sockets/main.sock', 0o775)
        except OSError:
            # Best effort: a failure here is ignored
            pass

        # Launch everything
        stdout, stderr = b"", b""
        if not ok_to_start:
            pass  # do not start ;-)
        elif debug != "ssh":  # normal start
            run_final_cmd = self._detect_run_file(data.get("run_cmd"))

            if run_final_cmd is not None:
                # The command is run from the task directory
                os.chdir("/task")
                try:
                    user = "root" if self.run_as_root else "worker"
                    stdout, stderr = execute_process(run_final_cmd, user=user)
                except:
                    # Any failure of the command is reported as a crash of the task
                    inginious_container_api.feedback.set_global_result('crash')
                    inginious_container_api.feedback.set_global_feedback(
                        "An error occured while running the grading script. It is possible that it is non-executable or made a timeout")
            else:
                inginious_container_api.feedback.set_global_result('crash')
                inginious_container_api.feedback.set_global_feedback("'/task/run' could not be found")
        else:
            # Start the ssh server
            ssh_user, password = start_ssh_server("root" if self.run_as_root else "worker")
            # Send ssh information. This method does not run inside the event loop, so the
            # coroutine is handed over to the loop with call_soon_threadsafe.
            self._loop.call_soon_threadsafe(asyncio.ensure_future, self.send_intern_message({
                "type": "ssh_debug",
                "ssh_user": ssh_user,
                "ssh_key": password}))
            # Wait for user to connect and leave (user can stay connected max 30 min)
            status = ssh_wait(ssh_user, 30*60)
            # Status 253 is reported to the user as "nobody connected"
            if status == 253:
                inginious_container_api.feedback.set_global_result('crash')
                inginious_container_api.feedback.set_global_feedback('[SSH Debug] Nobody connected to the SSH server.')
            stdout, stderr = b"", b""

        # Produce feedback
        feedback = inginious_container_api.feedback.get_feedback()
        if not feedback:
            # No feedback at all: report a crash (no archive is attached in this case)
            result = {"result": "crash", "text": "No feedback was given !", "problems": {}, "tests": {}}
            if debug:
                result['stdout'] = stdout.decode('utf-8', 'replace')
                result['stderr'] = stderr.decode('utf-8', 'replace')
            # Rights are applied again on /task before returning
            self.set_directory_rights('/task')
            self._logger.info("returning results")
            return result
        else:
            if debug:
                feedback['stdout'] = stdout.decode('utf-8', 'replace')
                feedback['stderr'] = stderr.decode('utf-8', 'replace')
            # The content of /archive is attached to the feedback, base64-encoded
            feedback['archive'] = self.b64tararchive()
            # Rights are applied again on /task before returning
            self.set_directory_rights('/task')
            self._logger.info("returning results")
            return feedback

    @classmethod
    def start(cls):
        """
        Configure the "inginious" logger, create the ZeroMQ context and event loop
        and run serve() until it returns.

        The runner is built from ``cls``, so a subclass calling start() runs itself.
        The loop and the context are released even if serve() raises.
        """
        logger = logging.getLogger("inginious")
        logger.setLevel(logging.DEBUG)
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        ch.setFormatter(formatter)
        logger.addHandler(ch)

        context = zmq.asyncio.Context()
        loop = zmq.asyncio.ZMQEventLoop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(cls(context, loop).serve())
        finally:
            loop.close()
            context.destroy(1)


# Entry point: the runner is started as soon as this file is executed. There is no
# ``__main__`` guard, so importing this module starts the runner as well.
INGIniousMainRunner.start()
