#!/usr/bin/env python3
import os
import rospy
from yag_slam.graph_slam import GraphSlam, make_near_scan_visitor
from yag_slam.scan_matching import Scan2DMatcherCpp
from yag_slam.graph import do_breadth_first_traversal
from karto_scanmatcher import create_occupancy_grid, Pose2
from yag_slam.models import LocalizedRangeScan
from yag_slam.splicing import map_to_graph
from tiny_tf.tf import Transform
from tiny_tf.transformations import euler_from_quaternion, quaternion_from_euler
from slam_toolbox_msgs.srv import SerializePoseGraph, SerializePoseGraphResponse
import tf2_geometry_msgs
from sensor_msgs.msg import LaserScan
from geometry_msgs.msg import Pose, TransformStamped
from nav_msgs.msg import OccupancyGrid, MapMetaData
from nav_msgs.srv import GetMapResponse, GetMap
import tf2_ros
from collections import namedtuple
import time
import cv2
from threading import Thread, Lock
import traceback
from queue import Queue
import argh
import yaml


def tftopose2(msg):
    """Convert a stamped transform message into a planar ``Pose2``.

    Args:
        msg: Object exposing ``transform.translation`` (x, y) and
            ``transform.rotation`` (x, y, z, w), e.g. a ``TransformStamped``.

    Returns:
        ``Pose2`` built from the translation's x and y and the last Euler
        angle that ``euler_from_quaternion`` reports for the rotation.
    """
    translation = msg.transform.translation
    rotation = msg.transform.rotation
    quat = (rotation.x, rotation.y, rotation.z, rotation.w)
    # Only the final Euler component is kept; the other two are discarded.
    heading = euler_from_quaternion(quat)[-1]
    return Pose2(translation.x, translation.y, heading)

def tftoTransform(msg):
    """Convert a stamped transform message into a tiny_tf ``Transform``.

    Args:
        msg: Object exposing ``transform.translation`` (x, y, z) and
            ``transform.rotation`` (x, y, z, w).

    Returns:
        ``Transform`` constructed from the three translation components
        followed by the four quaternion components, in that order.
    """
    xyz = msg.transform.translation
    quat = msg.transform.rotation
    return Transform(xyz.x, xyz.y, xyz.z, quat.x, quat.y, quat.z, quat.w)


def pose2toPose(pose):
    """Build a ``geometry_msgs`` ``Pose`` from an object with ``x`` and ``y``.

    Only the planar position is copied. The orientation is always the
    quaternion for zero roll, pitch and yaw, whatever heading ``pose`` may
    carry, and ``position.z`` is left at the message default.

    Args:
        pose: Object exposing ``x`` and ``y`` attributes.

    Returns:
        The populated ``Pose`` message.
    """
    result = Pose()
    qx, qy, qz, qw = quaternion_from_euler(0, 0, 0)
    result.orientation.x = qx
    result.orientation.y = qy
    result.orientation.z = qz
    result.orientation.w = qw
    result.position.x = pose.x
    result.position.y = pose.y
    return result


def TransformToTfTransform(xform, frame_id, child_frame_id):
    """Wrap a transform object in a ``TransformStamped`` stamped with "now".

    Args:
        xform: Object exposing ``x``, ``y`` and a 4-element ``quaternion``
            that unpacks as (x, y, z, w).
        frame_id: Parent frame written to ``header.frame_id``.
        child_frame_id: Child frame written to ``child_frame_id``.

    Returns:
        The populated ``TransformStamped``. Only the x and y translation are
        copied; the z translation keeps the message default.
    """
    stamped = TransformStamped()
    stamped.header.stamp = rospy.Time.now()
    stamped.header.frame_id = frame_id
    stamped.child_frame_id = child_frame_id

    translation = stamped.transform.translation
    translation.x = xform.x
    translation.y = xform.y

    rotation = stamped.transform.rotation
    rotation.x, rotation.y, rotation.z, rotation.w = xform.quaternion
    return stamped


def _p(name, default):
    """Fetch a ROS parameter and echo the lookup to stdout.

    Args:
        name: Parameter name passed to ``rospy.get_param``.
        default: Value passed as the fallback to ``rospy.get_param``.

    Returns:
        Whatever ``rospy.get_param(name, default)`` returned.
    """
    fetched = rospy.get_param(name, default)
    report = "Queried param {} with default {} and got {}".format(name, default, fetched)
    print(report)
    return fetched


class YAGSlamRos1(Thread):
    """ROS 1 front end for a yag_slam ``GraphSlam`` mapper.

    Laser scans received on ``/scan`` are paired with their odom-frame sensor
    pose in :meth:`process_scan` and queued. The thread body (:meth:`run`)
    feeds queued scans to the mapper, and a second daemon thread
    (:meth:`_send_map`) renders and publishes the occupancy grid whenever a
    request is placed on ``map_queue``.
    """

    def __init__(self, base_map_path=None, x=None, y=None, th=None):
        """Read parameters, optionally ingest a base map, and start the threads.

        Args:
            base_map_path: Optional path to a map YAML file (see
                :meth:`ingest_base_map`). When given, the loop-search chain
                size and the response thresholds are overridden.
            x, y, th: Initial pose applied to the first processed scan. It is
                only used when all three are not ``None``.
        """
        Thread.__init__(self)
        self.setup = False

        # Everything the ROS callbacks and worker threads read is created
        # before any service or subscriber is registered, so a callback that
        # fires immediately cannot hit a missing attribute.
        self.scan_config = None
        self.mapper = None
        self.last_pose = None
        self.scans = []
        self.queue = Queue()
        self.map_queue = Queue()

        self.tfb = tf2_ros.Buffer()
        self.tbr = tf2_ros.TransformBroadcaster()
        self.tfl = tf2_ros.TransformListener(self.tfb)

        print(base_map_path, x, y, th)

        self.odom_frame = _p('~odom_frame', 'odom')
        self.map_frame = _p('~map_frame', 'map')
        self.sensor_frame = _p('~sensor_frame', 'base_laser_link')
        self.min_distance = _p('~min_distance', 0.5)
        self.min_rotation = _p('~min_rotation', 0.5)
        self.loop_search_min_chain_size = _p('~loop_search_min_chain_size', 10)
        self.loop_search_distance = _p('~loop_search_distance', 4.0)
        self.min_response_coarse = _p('~min_response_coarse', 0.6)
        self.min_response_fine = _p('~min_response_fine', 0.7)
        self.range_threshold_for_map = _p('~range_threshold_for_map', 12)
        self.range_threshold = _p('~range_threshold', 20)
        self.scan_buffer_len = _p('~scan_buffer_len', 10)
        self.map_resolution = _p('~map_resolution', 0.05)

        self.initial_xform = None if (x is None or y is None or th is None) else (x, y, th)

        if base_map_path:
            print("WARNING: A base map was selected, forcing loop search min_chain size to be 2")
            self.loop_search_min_chain_size = 2
            self.min_response_coarse = 0.25
            self.min_response_fine = 0.35
            self.ingest_base_map(base_map_path)

        # Explicit queue_size: rospy warns when it is omitted and falls back
        # to synchronous publishing.
        self.map_pub = rospy.Publisher("/map", OccupancyGrid, latch=True, queue_size=1)
        self.map_meta_pub = rospy.Publisher("/map_metadata", MapMetaData, queue_size=1)
        self.last_map_update_time = time.time()

        self.daemon = True
        self.map_thread = Thread(target=self._send_map)
        self.map_thread.daemon = True
        self.map_thread.start()
        self.start()

        # Callback-driven endpoints come last, once all state and both worker
        # threads exist.
        self.save_graph_service = rospy.Service('yag_slam/save_graph', SerializePoseGraph, self.save_graph)
        self.map_srv = rospy.Service('dynamic_map', GetMap, self._map_callback)
        self.scan_sub = rospy.Subscriber("/scan", LaserScan, self.process_scan)

    @staticmethod
    def _yaw_delta(a, b):
        """Return the absolute difference between two yaw angles, wrapped to [0, pi]."""
        import math  # local import: the module does not import math at file level
        diff = a - b
        return abs(math.atan2(math.sin(diff), math.cos(diff)))

    def save_graph(self, trigger_cmd):
        """Write the mapper's binarized graph to disk.

        Args:
            trigger_cmd: ``SerializePoseGraph`` request whose ``filename`` is
                the destination, or ``None`` for internal calls.

        Returns:
            An empty ``SerializePoseGraphResponse``.
        """
        out = self.mapper.binarize()
        requested = trigger_cmd.filename if trigger_cmd is not None else None
        # An empty filename in the request falls back to the default path
        # instead of failing inside open("").
        fixed_path = requested or "/tmp/map.graph"
        with open(fixed_path, "wb") as ff:
            print("Saving graph at {}".format(fixed_path))
            ff.write(out)
        return SerializePoseGraphResponse()

    def ingest_base_map(self, base_map_path):
        """Seed the mapper with vertices derived from an existing map image.

        Args:
            base_map_path: Path to a YAML file providing the ``image``,
                ``resolution`` and ``origin`` keys read below.

        Raises:
            IOError: If the image referenced by the YAML cannot be read.
        """
        # LAST_2D_ROBOT_POSE_IN_MAP_FRAME
        # CURRENT_STATIC_MAP
        # ROBOT_CURRENT_LOCALE
        with open(base_map_path) as ff:
            data = yaml.safe_load(ff)
        # os.path.join keeps a relative image path relative when the YAML path
        # has no directory part (plain concatenation produced "/<image>") and
        # leaves an absolute image path untouched.
        image_path = os.path.join(os.path.dirname(base_map_path), data["image"])
        raw = cv2.imread(image_path)
        if raw is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError("Could not read base map image {}".format(image_path))
        # Flip the rows and keep only the first channel.
        im = raw[::-1, :, 0].copy()

        self.setup_mapper()

        # NOTE(review): this module imports `map_to_graph` from
        # yag_slam.splicing but calls `map_to_graphslam`, which is not bound
        # anywhere in this file -- confirm the intended name and import.
        self.mapper = map_to_graphslam(self.mapper, im, data["resolution"], [data["origin"][0], data["origin"][1]], density=5)
        for ii, v in enumerate(self.mapper.graph.vertices):
            v.obj.num = ii
            self.scans.append(v.obj)

        # to fix opt issues for now
        self.mapper = GraphSlam.deserialize(self.mapper.serialize())

    def setup_mapper(self):
        """Create the sequential and loop scan matchers and the ``GraphSlam`` instance."""
        seq_scan_matcher_config = {
            "angle_variance_penalty": _p("~angle_variance_penalty", 0.349),
            "distance_variance_penalty": _p("~distance_variance_penalty", 0.3),
            "coarse_search_angle_offset": _p("~coarse_search_angle_offset", 0.349),
            "coarse_angle_resolution": _p("~coarse_angle_resolution", 0.0349),
            "fine_search_angle_resolution": _p("~fine_search_angle_resolution", 0.00349),
            "use_response_expansion": _p("~use_response_expansion", True),
            "range_threshold": _p("~range_threshold", 20),
            "minimum_angle_penalty": _p("~minimum_angle_penalty", 0.9),
            "search_size": _p("~sequential_matching_search_size", 0.3),
            "resolution": _p("~sequential_matching_resolution", 0.01),
            "smear_deviation": _p("~sequential_matching_smear_deviation", 0.07)
        }

        # The loop matcher shares every setting except the three overridden here.
        loop_scan_matcher_config = seq_scan_matcher_config.copy()
        loop_scan_matcher_config.update({
            "search_size": _p("~loop_matching_search_size", 4.0),
            "resolution": _p("~loop_matching_resolution", 0.05),
            "smear_deviation": _p("~loop_matching_smear_deviation", 0.03),
        })

        seq_matcher = Scan2DMatcherCpp(config_dict=seq_scan_matcher_config)
        loop_matcher = Scan2DMatcherCpp(loop_scan_matcher_config, loop=True)

        self.mapper = GraphSlam(
            seq_matcher, loop_matcher,
            loop_search_min_chain_size=self.loop_search_min_chain_size,
            scan_buffer_len=self.scan_buffer_len,
            min_response_coarse=self.min_response_coarse,
            min_response_fine=self.min_response_fine)

    def _map_callback(self, req):
        """Handle ``dynamic_map``: return a freshly rendered map in a ``GetMapResponse``."""
        resp = GetMapResponse()
        resp.map = self._make_map()
        return resp

    def _make_map(self, msg=True):
        """Render the occupancy grid from every vertex scan in the graph.

        Args:
            msg: When true (default) return an ``OccupancyGrid`` message;
                otherwise return the remapped int16 image array.
        """
        grid = create_occupancy_grid([v.obj._scan for v in self.mapper.graph.vertices], self.map_resolution, self.range_threshold_for_map)
        im = grid.image

        # do cleanup: inverted pixels below 200 are zeroed, then connected
        # components whose last stats column (pixel area in OpenCV's stats
        # layout) is under 5 are overwritten with 255 in the image.
        static_only = 255-im.copy()
        static_only[static_only < 200] = 0
        num_conn, mask, stats, position = cv2.connectedComponentsWithStats(static_only)
        for ii, stat in enumerate(stats):
            if stat[-1] < 5:
                im[mask == ii] = 255

        # Remap pixel values to cell values: 0 -> 100, 200 -> -1, 255 -> 0.
        # int16 is needed so that -1 can be stored before the int8 cast below.
        im = im.astype('int16')
        im[im == 0] = 100
        im[im == 200] = -1
        im[im == 255] = 0
        if not msg:
            return im
        map_msg = OccupancyGrid()
        map_msg.info.resolution = self.map_resolution
        map_msg.info.height, map_msg.info.width = grid.height, grid.width
        map_msg.data = im.flatten().astype('int8').tolist()
        map_msg.info.origin = pose2toPose(grid.offset)
        map_msg.header.frame_id = self.map_frame
        return map_msg

    def _send_map(self):
        """Map thread body: publish a map for every item taken from ``map_queue``.

        The string ``"true"`` additionally saves the graph to the default
        path; any other item only publishes.
        """
        while True:
            res = self.map_queue.get()
            map_msg = self._make_map()
            self.map_pub.publish(map_msg)
            self.map_meta_pub.publish(map_msg.info)
            if res == "true":
                self.save_graph(None)

    def run(self):
        """Worker thread body: feed queued scans to the mapper and request map updates."""
        map_counter = 0
        while True:
            msg, pose, invert_scan = self.queue.get()

            # Ranges are reversed when process_scan flagged the sensor as upside down.
            ranges = msg.ranges[::-1] if invert_scan else msg.ranges

            scan = LocalizedRangeScan(
                ranges, msg.angle_min, msg.angle_max, msg.angle_increment,
                msg.range_min, msg.range_max, self.range_threshold, pose.x, pose.y, pose.yaw)

            if self.initial_xform is not None:
                # Temporarily place the scan at the requested initial pose; the
                # real odom pose is restored after the scan has been processed.
                odom_pose_bkup = scan.odom_pose
                scan.odom_pose = Transform.from_xyt(*(self.initial_xform))
                scan.corrected_pose = Transform.from_xyt(*(self.initial_xform))
                print("Setting initial transform to {}".format(self.initial_xform))

            if not self.mapper.running_scans and self.scans and self.initial_xform:
                print("Ingesting a new scan for connecting to existing map")
                # first scan when splicing into a map
                slam_fake = self.mapper
                scan.num = max([v.obj.num for v in self.mapper.graph.vertices]) + 1

                nearby_scans = slam_fake.search.crude_radius_search(scan.odom_pose, 5)
                res = slam_fake.seq_matcher.match_scan(scan, [v.obj for v in nearby_scans], do_fine=True)
                scan.corrected_pose = res.best_pose
                slam_fake.add_vertex(scan)
                # NOTE(review): nearby_scans[0] assumes the radius search found
                # at least one vertex -- confirm behaviour for a far-off pose.
                slam_fake.link_scans(scan, nearby_scans[0].obj, None, res.covariance)
                slam_fake.running_scans.append(scan)
                closed = True
                self.save_graph(None)
            else:
                res, closed = self.mapper.process_scan(scan)

            if self.initial_xform is not None:
                scan.odom_pose = odom_pose_bkup
                self.initial_xform = None

            if len(self.mapper.running_scans) == 1:
                self.last_map_update_time = time.time()
                self.map_queue.put(True)

            map_counter += 1
            # "true" (a string) makes _send_map save the graph as well as publish.
            if (map_counter >= 5 or closed) and self.map_queue.qsize() == 0:
                self.last_map_update_time = time.time()
                self.map_queue.put("true")
                map_counter = 0

    def process_scan(self, msg):
        """``/scan`` callback: broadcast map->odom and queue scans that moved enough.

        A scan is dropped when the sensor moved less than ``min_distance`` and
        turned less than ``min_rotation`` since the last accepted scan.

        Args:
            msg: Incoming ``LaserScan`` message.
        """
        if not self.mapper:
            self.setup_mapper()

        # Request a map refresh if none has been requested for five seconds.
        if (time.time() - self.last_map_update_time >= 5):
            self.last_map_update_time = time.time()
            self.map_queue.put(True)

        try:
            odom_to_sensor = self.tfb.lookup_transform(
                self.odom_frame, self.sensor_frame, msg.header.stamp, rospy.Duration(0.1))
        except Exception:
            traceback.print_exc()
            print("Looking up tf for {} from {} failed at {}".format(self.sensor_frame, self.odom_frame, time.time()))
            return

        if self.mapper.running_scans:
            # Derive the broadcast transform from the newest running scan's
            # odom pose and corrected pose.
            ls = self.mapper.running_scans[-1]
            odom_to_map = (Transform.from_pose2d(ls.odom_pose) + Transform.from_pose2d(ls.corrected_pose).inverse())
            self.tbr.sendTransform(TransformToTfTransform(odom_to_map.inverse(), self.map_frame, self.odom_frame))

        pose = tftopose2(odom_to_sensor)

        if not self.last_pose:
            self.last_pose = pose
        else:
            last = self.last_pose
            moved_sq = (pose.x - last.x)**2 + (pose.y - last.y)**2
            # Wrapped difference: a heading that jitters across +/-pi would
            # otherwise look like a ~2*pi turn and defeat the gate.
            turned = self._yaw_delta(pose.yaw, last.yaw)
            if moved_sq < self.min_distance**2 and turned < self.min_rotation:
                return
            self.last_pose = pose

        # if the z of the below is -ve, then sensor is upside down
        up_xform = tftoTransform(odom_to_sensor) + Transform(0, 0, 100, 0, 0, 0, 1)

        if self.queue.qsize() > 1:
            print("{}: Current queue size {}".format(time.time(), self.queue.qsize()))

        self.queue.put((msg, pose, up_xform.z < 0))

# Entry point dispatched by argh; the keyword-only parameters are the arguments
# handed to argh.dispatch_command. A `#` comment is used instead of a docstring
# so the text argh derives from the function itself stays unchanged.
def main(*, base_map_path=None, x=0.0, y=0.0, th=0.0):
    rospy.init_node('yag_slam', anonymous=False)
    # Constructing the node starts its worker threads; the name keeps a
    # reference alive while rospy.spin() blocks.
    node = YAGSlamRos1(base_map_path, x, y, th)
    rospy.spin()


if __name__ == '__main__':
    argh.dispatch_command(main)
