#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import numpy as np
import matplotlib.pyplot as plt
import math

import scipy.constants as const


def first(x):
    """Return the element stored at index 0 of *x*."""
    head = x[0]
    return head


def gaussian(x, sigma):
    """Evaluate a zero-centred, unit-area Gaussian of width *sigma* at *x*.

    The exponential is taken with ``np.exp``, so *x* may be a scalar or a
    NumPy array; *sigma* is the standard deviation.
    """
    prefactor = 1/(sigma*math.sqrt(2*math.pi))
    scaled = x/sigma
    return prefactor*np.exp(-scaled**2/2)


def fwhm_to_sigma(fwhm):
    """Convert a Gaussian full width at half maximum to a standard deviation.

    For a Gaussian, ``FWHM = 2*sqrt(2*ln 2)*sigma``; this returns
    ``fwhm / (2*sqrt(2*ln 2))``.

    Raises:
        ValueError: if *fwhm* is not strictly positive.
    """
    if fwhm <= 0:
        raise ValueError('FWHM must be positive, got %r' % (fwhm,))
    return fwhm/(2.0*math.sqrt(2.0*math.log(2.0)))


def extend_axis(axis, target_len, step):
    """Pad *axis* on both sides with points spaced by *step*.

    The missing ``target_len - len(axis)`` points are split between the two
    ends; when the number is odd the extra point goes to the left end.

    Args:
        axis: non-empty 1-D sequence of coordinates in ascending order.
        target_len: required length of the returned axis.
        step: spacing of the added points.

    Returns:
        1-D float array of length *target_len*.

    Raises:
        ValueError: if *target_len* is smaller than ``len(axis)``.
    """
    axis = np.asarray(axis, dtype=float)
    missing = target_len - len(axis)
    if missing < 0:
        raise ValueError('target length %d is shorter than the axis (%d)'
                         % (target_len, len(axis)))
    n_left = (missing + 1)//2
    n_right = missing//2
    left = axis[0] - step*np.arange(n_left, 0, -1)
    right = axis[-1] + step*np.arange(1, n_right + 1)
    return np.concatenate((left, axis, right))


def broaden_spectrum(wavenumbers, intensities, sampling, fwhm):
    """Broaden a stick spectrum with a Gaussian of the given FWHM.

    The peaks are merged into a zero-valued grid of *sampling* evenly spaced
    points spanning the peak range. The merged intensity sequence is then
    convolved (``mode='full'``) with a Gaussian kernel sampled on
    ``[-4*sigma, 4*sigma)``.

    Args:
        wavenumbers: 1-D sequence of peak positions.
        intensities: 1-D sequence of peak heights, same length as
            *wavenumbers*.
        sampling: number of evenly spaced grid points (>= 1).
        fwhm: full width at half maximum of the Gaussian, in the unit of
            *wavenumbers*.

    Returns:
        Tuple ``(axis, signal)`` of two 1-D arrays of equal length.

    Raises:
        ValueError: if *sampling* < 1, *fwhm* <= 0, or all peaks share the
            same position (zero-width range).
    """
    wavenumbers = np.asarray(wavenumbers, dtype=float)
    intensities = np.asarray(intensities, dtype=float)
    if sampling < 1:
        raise ValueError('sampling must be at least 1, got %r' % (sampling,))

    lowest = wavenumbers.min()
    highest = wavenumbers.max()
    span = highest - lowest
    if span <= 0:
        raise ValueError('the peaks must span a non-zero wavenumber range')

    # Merge the empty grid with the peaks and order by position.  The stable
    # sort keeps a grid point ahead of a peak sitting at the same position.
    positions = np.concatenate((np.linspace(lowest, highest, num=sampling),
                                wavenumbers))
    weights = np.concatenate((np.zeros(sampling), intensities))
    order = np.argsort(positions, kind='stable')
    positions = positions[order]
    weights = weights[order]

    # Kernel spacing: the range divided by the number of merged points.
    dx = span/len(weights)
    sigma = fwhm_to_sigma(fwhm)
    kernel = gaussian(np.arange(-4*sigma, 4*sigma, dx), sigma)

    signal = np.convolve(weights, kernel, mode='full')

    # The full convolution is len(kernel) - 1 points longer than the input;
    # grow the axis to match, using the grid step range/sampling.
    axis = extend_axis(positions, len(signal), span/sampling)
    return axis, signal


if __name__ == '__main__':
    if len(sys.argv) < 6:
        sys.exit('usage: %s INFILE OUTFILE FIGURE SAMPLING FWHM'
                 % sys.argv[0])
    infilename = sys.argv[1]
    outfilename = sys.argv[2]
    savefig = sys.argv[3]
    sampling = int(sys.argv[4])
    FWHM = float(sys.argv[5])

    # e^2/h; the plotted signal is divided by it (labelled g_0 on the y axis).
    g0 = (const.e**2)/const.h
    # Factor turning a wavenumber in 1/cm into a bias voltage in mV (h*c/e).
    wavenum_1cm2mV = 1e5*const.c*const.h/const.e

    # The number of columns on the first line selects the input layout.
    with open(infilename, 'r') as fp:
        ncols = len(fp.readline().split())

    if ncols == 4:
        iets = np.loadtxt(infilename, dtype={
            'names': ('mode_idx', 'wavenumber', 'energy', 'intensity'),
            'formats': ('i8', 'f8', 'f8', 'f8')
        })
        channels = [(outfilename, iets['intensity'])]
    elif ncols == 8:
        iets = np.loadtxt(infilename, dtype={
            'names': ('mode_idx', 'wavenumber', 'energy', 'intensity_tot',
                      'intensity_uu', 'intensity_du',
                      'intensity_ud', 'intensity_dd'),
            'formats': ('i8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8')
        })
        # Only the uu and dd channels are broadened; each one is written to
        # its own file, OUTFILE1 and OUTFILE2.
        channels = [(outfilename + str(idx), intensity)
                    for idx, intensity in enumerate([iets['intensity_uu'],
                                                     iets['intensity_dd']],
                                                    1)]
    else:
        sys.exit('%s: expected 4 or 8 columns, found %d'
                 % (infilename, ncols))

    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.set_ylabel(r'$W_{\alpha}$ / $g_0$ [$\Omega^{-1}$]')
    ax1.set_xlabel('Wavenumber [1/cm]')
    ax2 = ax1.twiny()
    ax2.set_xlabel('Bias Voltage [mV]')

    for path, intensity in channels:
        p_wavenum, result = broaden_spectrum(iets['wavenumber'], intensity,
                                             sampling, FWHM)
        ax1.plot(p_wavenum, result/g0)
        # Invisible line whose only job is to give the top axis the voltage
        # range matching the wavenumber range of the bottom axis.
        ax2.plot([p_wavenum.min()*wavenum_1cm2mV,
                  p_wavenum.max()*wavenum_1cm2mV],
                 2*[0.0], ' ')
        np.savetxt(path, np.column_stack((p_wavenum, result)),
                   fmt=['%.18e', '%.18e'])

    if ncols == 8:
        ax1.yaxis.grid(True)
    plt.savefig(savefig, bbox_inches='tight')
