#!/usr/bin/env python3


''' --------------------------------------------------------------------------
20170730:
--------

Shared memory viewer.

-------------------------------------------------------------------------- '''
from xaosim.QtMain import QtMain, QApplication
from PyQt5 import QtCore, QtGui, uic
from PyQt5.QtCore import QThread
from PyQt5.QtWidgets import QLabel, QMainWindow, QFileDialog
from PyQt5.QtGui import QImage

import threading

import pyqtgraph as pg
import sys
import numpy as np
from xaosim.shmlib import shm
import matplotlib.cm as cm

import os

# =====================================================================
# user home directory, read from the HOME environment variable
# (None when that variable is not set)
home = os.environ.get('HOME')
# directory looked up first for the 'shmimview.ui' layout file
conf_dir = "./"  # home+'/.config/xaosim/'

# =====================================================================
# =====================================================================
myqt = 0  # module-level handle on the QtMain object, assigned by main()

def main():
    ''' --------------------------------------------------------
    Program entry point.

    Creates the QtMain object (stored in the global myqt),
    instantiates the main window, runs the main loop and, once
    the loop returns, shuts the GUI down and exits the
    interpreter.
    -------------------------------------------------------- '''
    global myqt
    myqt = QtMain()
    # MyWindow.__init__ calls show() itself: building the object is
    # all that is required to put the window on screen
    window = MyWindow()
    myqt.mainloop()
    myqt.gui_quit()
    sys.exit()


# =====================================================================
#                               Tools
# =====================================================================
def arr2im(arr, vmin=False, vmax=False, pwr=1.0, cmap=None, gamma=1.0):
    ''' --------------------------------------------------------
    convert 2D numpy array into image for display

    limits dynamic range, power coefficient and applies colormap

    Parameters:
    ----------
    - arr  : the 2D array to convert (left untouched: a float copy
             is what gets modified)
    - vmin : lower display limit. False (default) or None: use the
             minimum of the array
    - vmax : upper display limit. False (default) or None: use the
             maximum of the array
    - pwr  : exponent applied to the normalized data (1.0: linear)
    - cmap : callable turning an array of values in [0, 1] into an
             array with a trailing RGBA axis. Default: cm.magma
    - gamma: value written in the 4th (alpha) channel of the result

    Returns the array produced by the colormap, with its alpha
    channel set to *gamma*.

    Note:
    ----
    The normalized data is clipped to [0, 1] before the power law
    is applied. Without that step, pixels below *vmin* end up
    negative and a fractional exponent turns them into NaNs.
    -------------------------------------------------------- '''
    arr2 = np.asarray(arr).astype('float')  # local copy is modified
    auto_min = vmin is False or vmin is None
    auto_max = vmax is False or vmax is None
    mmin = arr2.min() if auto_min else vmin
    mmax = arr2.max() if auto_max else vmax
    mycmap = cm.magma if cmap is None else cmap

    arr2 -= mmin
    if mmax != mmin:  # a flat image is left un-normalized
        arr2 /= (mmax - mmin)

    # saturate whatever lies outside of the requested display range
    np.clip(arr2, 0.0, 1.0, out=arr2)
    arr2 = arr2**pwr

    res = mycmap(arr2)
    res[:, :, 3] = gamma
    return res

# =====================================================================
#                          Main GUI object
# =====================================================================
# command-line arguments, program name excluded. In the code visible
# here, MyWindow reads sys.argv directly rather than this variable.
args = sys.argv[1:]


class MyWindow(QMainWindow):
    ''' ------------------------------------------------------
    This is the meat of the program: the class that drives
    the GUI.

    The window layout is loaded from the 'shmimview.ui' file.
    A QTimer calls refresh_all() every 100 ms to redraw the
    image and the statistics panel from the attached shared
    memory (SHM) structure, when there is one.

    Optional command line arguments:
    - sys.argv[1]: name of the SHM file to attach at start-up
    - sys.argv[2]: index of the slice to display (data cubes)
    ------------------------------------------------------ '''
    def __init__(self):
        ''' ------------------------------------------------------
        Builds the GUI, wires the widgets to their callbacks,
        processes the command line and starts the refresh timer.
        ------------------------------------------------------ '''
        # the Qt base class is set up before the instance is used
        super(MyWindow, self).__init__()

        # viewer state: everything is given a value up-front so that
        # any callback can fire safely before a SHM is attached
        self.mySHM = None       # handle for mmapped SHM file
        self.myDRK = None       # dark frame (None: no dark available)
        self.data_img = None    # image currently on display
        self.naxis = None       # number of axes of the attached SHM
        self.myslice = 0        # slice displayed for a data cube
        self.live_counter = -1  # counter of the last image displayed
        self._dark_msg = None   # dark status for the displayed image
        self.vmin = False
        self.vmax = False
        self.pwr = 1.0
        self.mycmap = cm.gray

        ui_file = conf_dir + 'shmimview.ui'
        if not os.path.exists(ui_file):
            ui_file = 'shmimview.ui'
        uic.loadUi(ui_file, self)

        # ==============================================
        # prepare the display
        # ==============================================

        self.gView_shm.hideAxis('left')
        self.gView_shm.hideAxis('bottom')
        self.imv_data = pg.ImageItem()
        self.overlay = pg.GraphItem()
        self.gView_shm.addItem(self.imv_data)

        # ==============================================
        #             GUI widget actions
        # ==============================================
        self.dspB_disp_min.valueChanged[float].connect(self.update_vmin)
        self.dspB_disp_max.valueChanged[float].connect(self.update_vmax)
        self.chB_min.stateChanged[int].connect(self.update_vmin)
        self.chB_max.stateChanged[int].connect(self.update_vmax)
        self.chB_nonlinear.stateChanged[int].connect(self.update_nonlinear)
        self.chB_dark_sub.stateChanged[int].connect(self.update_dark_state)
        self.cmB_cbar.addItems(
            ['magma', 'gray', 'hot', 'cool', 'bone', 'jet', 'viridis'])
        self.cmB_cbar.activated[str].connect(self.update_cbar)
        self.cmB_cbar.setCurrentIndex(0)
        self.update_cbar()

        # ==============================================
        #             top-menu actions
        # ==============================================
        self.actionQuit.triggered.connect(sys.exit)
        self.actionQuit.setShortcut('Ctrl+Q')

        self.actionOpen.triggered.connect(self.load_shm)
        self.actionLoadDark.triggered.connect(self.load_shm_drk)
        self.setMinimumSize(600, 400)

        # ==============================================
        #             command line arguments
        # ==============================================
        cli_args = sys.argv[1:]
        if cli_args:
            target = str(cli_args[0])
            if os.path.isfile(target):
                self._attach_shm(target)
            else:
                print("No file loaded")

            if len(cli_args) > 1:
                try:
                    self.myslice = int(cli_args[1])
                except ValueError:
                    print("Not a valid slice")
                    self.myslice = 0
                # the upper bound can only be checked against a SHM
                # that was successfully attached
                # NOTE(review): the bound is taken from the last item
                # of mtdata['size'] while refresh_all() indexes the
                # first axis of the data: confirm they always agree
                if self.mySHM is not None:
                    self.myslice = min(self.myslice,
                                       self.mySHM.mtdata['size'][-1]-1)

        # ==============================================
        self.show()

        self.timer = QtCore.QTimer()
        self.timer.timeout.connect(self.refresh_all)
        self.timer.start(100)

    # =========================================================
    def closeEvent(self, event):
        ''' Closing the window terminates the program. '''
        sys.exit()

    # =========================================================
    def _attach_shm(self, fname):
        ''' ----------------------------------------------------------
        Attach the SHM file *fname* and reset the display state.

        Shared by the command line processing and the "Open" menu
        action, so that both leave the viewer in the same state.

        Returns True if the SHM is usable, False if it reports
        itself as empty (the viewer is then left without a SHM).
        ---------------------------------------------------------- '''
        self.array_title.setText(fname)
        self.mySHM = shm(str(fname))
        self.live_counter = -1  # forces the next refresh to reload
        self.myslice = 0
        self._dark_msg = None

        if self.mySHM.empty:
            self.mySHM = None
            self.myDRK = None
            return False

        self.naxis = self.mySHM.mtdata['naxis']
        self.myDRK = np.zeros_like(self.mySHM.get_data())  # null dark
        self.auto_resize_window()
        return True

    # =========================================================
    def _force_refresh(self):
        ''' ----------------------------------------------------------
        Make the next refresh_all() reload the image even if the
        SHM counter has not changed since the last display.
        ---------------------------------------------------------- '''
        if self.mySHM is None:
            return
        if self.live_counter == self.mySHM.get_counter():
            self.live_counter -= 1

    # =========================================================
    def update_cbar(self):
        ''' Use the colormap selected in the combo box (default: jet) '''
        cbar = str(self.cmB_cbar.currentText()).lower()
        # attribute lookup with a fallback: no need to exec() a string
        self.mycmap = getattr(cm, cbar, cm.jet)

    # =========================================================
    def update_vmin(self):
        ''' Lower display limit: spin box value if enabled, else auto '''
        if self.chB_min.isChecked():
            self.vmin = self.dspB_disp_min.value()
        else:
            self.vmin = False

    # =========================================================
    def update_vmax(self):
        ''' Upper display limit: spin box value if enabled, else auto '''
        if self.chB_max.isChecked():
            self.vmax = self.dspB_disp_max.value()
        else:
            self.vmax = False

    # =========================================================
    def update_nonlinear(self):
        ''' Toggle between linear (1.0) and power 0.25 display scale '''
        self.pwr = 0.25 if self.chB_nonlinear.isChecked() else 1.0

    # =========================================================
    def update_dark_state(self):
        ''' Dark subtraction was toggled: reload the current image '''
        self._force_refresh()

    # =========================================================
    def refresh_img(self):
        ''' Redraw the current image with the current display settings '''
        rgba = arr2im(self.data_img.T,
                      vmin=self.vmin, vmax=self.vmax,
                      pwr=self.pwr,
                      cmap=self.mycmap)
        self.imv_data.setImage(rgba, border=2)

    # =========================================================
    def refresh_stats(self, add_msg=None):
        ''' ----------------------------------------------------------
        Update the statistics label: image size, SHM keywords,
        percentiles of the displayed image and image counter.

        Parameters:
        ----------
        - add_msg: optional extra line appended to the report
        ---------------------------------------------------------- '''
        pt_levels = [0, 5, 10, 20, 50, 75, 90, 95, 99, 100]
        pt_values = np.percentile(self.data_img, pt_levels)

        self.mySHM.read_keywords()
        # the % operator requires a tuple, whatever the container is
        imsize = tuple(self.mySHM.mtdata["size"])

        parts = ["<pre>\n"]
        if self.naxis == 3:
            parts.append("imsize = %d x %d x %d\n\n" % imsize)
        else:
            parts.append("imsize = %d x %d\n" % imsize[1:])

        parts.extend("%10s : %10s \n" % (kwd['name'], kwd['value'])
                     for kwd in self.mySHM.kwds)

        parts.extend("p-tile %3d = %8.2f\n" % (ptile, pvalue)
                     for ptile, pvalue in zip(pt_levels, pt_values))

        parts.append("img count  = %8d\n" % (self.mySHM.get_counter(),))

        if add_msg is not None:
            parts.append("%s\n" % (add_msg,))
        parts.append("</pre>")
        self.lbl_stats.setText("".join(parts))

    # =========================================================
    def auto_resize_window(self):
        ''' ----------------------------------------------------------
        Resize the window to match the aspect ratio of the image.

        Only the last two dimensions of the data are considered:
        for a data cube, they are the dimensions of the slices that
        get displayed.
        ---------------------------------------------------------- '''
        imsize = self.mySHM.get_data().shape
        if len(imsize) < 2 or 0 in imsize[-2:]:
            return  # no usable image dimensions to fit the window to

        nrows, ncols = imsize[-2:]
        imratio = ncols / nrows

        vsize = self.gView_shm.geometry()  # view size

        hratio = vsize.width() / ncols
        vratio = vsize.height() / nrows

        # the 200 x 46 margin accounts for the rest of the window
        if hratio > vratio:  # better horizontal "resolution"
            self.resize(vsize.width() + 200,
                        int(vsize.width() / imratio) + 46)
        else:
            self.resize(int(vsize.height() * imratio) + 200,
                        vsize.height() + 46)

    # =========================================================
    def load_shm(self):
        ''' "Open" menu action: select and attach a new SHM file '''
        fname = QFileDialog.getOpenFileName(
            self, 'Load SHM file', '/dev/shm/*.im.shm')[0]

        if fname == '':  # false alert: don't change anything
            return

        # a new shm file was indeed selected: release the current one
        if self.mySHM is not None:
            self.mySHM.close(erase_file=False)
            self.mySHM = None

        self._attach_shm(fname)

    # =========================================================
    def load_shm_drk(self):
        ''' ----------------------------------------------------------
        "Load dark" menu action: use the content of a SHM file as
        the dark frame. If the file cannot be read, fall back on a
        null dark matching the attached SHM (or on no dark at all
        when no SHM is attached).
        ---------------------------------------------------------- '''
        fname = QFileDialog.getOpenFileName(
            self, 'Load SHM file for dark', '/dev/shm/*.im.shm')[0]

        if fname == '':
            print("No new dark was loaded")
            return

        try:
            temp = shm(str(fname))
            self.myDRK = temp.get_data()
            temp = None
        except Exception as err:  # best effort: keep the GUI alive
            print("Dark could not be loaded: %s" % (err,))
            if self.mySHM is not None:
                self.myDRK = np.zeros_like(self.mySHM.get_data())
            else:
                self.myDRK = None

        # apply the new dark even if the SHM is not being updated
        self._force_refresh()

    # =========================================================
    def _subtract_dark(self):
        ''' ----------------------------------------------------------
        Subtract the dark from the current image, in place.

        Returns None on success, or a message for the statistics
        panel when the subtraction could not be done.
        ---------------------------------------------------------- '''
        failure = "Dark subtraction problem?"
        dark = self.myDRK
        if dark is None:
            return failure
        try:
            # dark cube: use the slice that matches the displayed one
            if np.ndim(dark) == self.data_img.ndim + 1:
                dark = dark[self.myslice]
            self.data_img -= dark
        except (AttributeError, IndexError, TypeError, ValueError):
            # an incompatible dark must not take the timer callback
            # (and the application along with it) down
            return failure
        return None

    # =========================================================
    def refresh_all(self):
        ''' ----------------------------------------------------------
        Refresh the display

        Timer callback: reloads the image when the SHM counter went
        up (or was reset to zero), then redraws the image and the
        statistics so that display settings apply immediately.
        ---------------------------------------------------------- '''
        self.test = 0
        if self.mySHM is None:
            return

        mycntr = self.mySHM.get_counter()
        if mycntr == 0:  # simulation has been stopped. Start anew!
            self.live_counter = -1

        if self.live_counter < mycntr:
            frame = self.mySHM.get_data(False, True)
            if self.naxis != 2:
                frame = frame[self.myslice]

            self.data_img = frame.astype(float)
            self.live_counter = self.mySHM.get_counter()

            # the status is kept with the image it refers to, so that
            # it stays on display until the next image is loaded
            self._dark_msg = None
            if self.chB_dark_sub.isChecked():
                self._dark_msg = self._subtract_dark()

        if self.data_img is None:
            return

        self.refresh_img()
        self.refresh_stats(add_msg=self._dark_msg)


# ==========================================================
# ==========================================================
# launch the viewer only when this file is executed as a script
if __name__ == "__main__":
    main()
