#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import numpy as np
import matplotlib.pyplot as plt
import math

import scipy.constants as const


def first(x):
    """Return the element at index 0 of ``x``.

    Used as a ``key`` function so that pairs are ordered by their first
    component only.
    """
    head = x[0]
    return head


def gaussian(x, sigma):
    """Evaluate a zero-centred, unit-area Gaussian at ``x``.

    ``x`` may be a scalar or a numpy array (the exponential is evaluated
    with ``np.exp``); ``sigma`` is the standard deviation.
    """
    prefactor = 1/(sigma*math.sqrt(2*math.pi))
    exponent = -(x/sigma)**2/2
    return prefactor*np.exp(exponent)


def fwhm_to_sigma(fwhm):
    """Convert a Gaussian full width at half maximum to a standard deviation.

    For a normal distribution FWHM = 2*sqrt(2*ln 2)*sigma, therefore
    sigma = FWHM / (2*sqrt(2*ln 2)).
    """
    return fwhm/(2*math.sqrt(2*math.log(2)))


def main(argv):
    """Broaden the spin-resolved IETS intensities and plot their difference.

    ``argv`` has the layout of ``sys.argv``:

    1. input table (columns: mode index, wavenumber, energy, total
       intensity and the uu/du/ud/dd intensities),
    2. prefix of the output data files (the channel index is appended),
    3. file name of the figure,
    4. number of equidistant sampling points,
    5. FWHM of the broadening Gaussian, in the units of the wavenumber
       column.

    For each of the ``uu`` and ``dd`` channels the stick spectrum is merged
    into an equidistant zero-valued grid, convolved with a Gaussian and
    written to ``<prefix><idx>``.  The absolute difference of both broadened
    channels, divided by e^2/h, is plotted and saved.
    """
    if len(argv) < 6:
        sys.exit('usage: {} INFILE OUTPREFIX FIGURE SAMPLING FWHM'
                 .format(argv[0]))
    infilename, outfilename, savefig = argv[1:4]
    sampling = int(argv[4])
    FWHM = float(argv[5])
    if sampling < 1:
        sys.exit('SAMPLING must be a positive integer')
    if FWHM <= 0:
        sys.exit('FWHM must be positive')

    g0 = (const.e**2)/const.h

    # atleast_1d: loadtxt yields a 0-d record for a single-row file.
    iets = np.atleast_1d(np.loadtxt(infilename, dtype={
        'names': ('mode_idx', 'wavenumber', 'energy', 'intensity_tot',
                  'intensity_uu', 'intensity_du',
                  'intensity_ud', 'intensity_dd'),
        'formats': ('i8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8')
    }))
    wavenumber = iets['wavenumber']
    if wavenumber.size == 0:
        sys.exit('no data rows found in ' + infilename)
    wn_min = wavenumber.min()
    wn_max = wavenumber.max()
    if wn_max <= wn_min:
        sys.exit('wavenumber range is empty; cannot build a sampling grid')

    # Everything below up to the loop is identical for both spin channels,
    # so it is computed only once.
    wavenum_sample = np.linspace(wn_min, wn_max, num=sampling)
    point_distance = (wn_max - wn_min)/sampling

    # Merge the equidistant grid with the mode positions.  A stable sort
    # keeps grid points ahead of modes that share the same wavenumber.
    grid = np.concatenate([wavenum_sample, wavenumber])
    order = np.argsort(grid, kind='stable')
    grid = grid[order]
    dx = (wn_max - wn_min)/len(grid)

    sigma = fwhm_to_sigma(FWHM)
    gx = np.arange(-4*sigma, sigma*4, dx)
    gauss = gaussian(gx, sigma)

    # mode='full' returns len(grid) + len(gauss) - 1 samples; extend the
    # axis by the surplus, the odd extra point going to the low end.
    n_extra = len(gauss) - 1
    n_left = (n_extra + 1)//2
    n_right = n_extra//2
    p_wavenum = np.concatenate([
        grid[0] - point_distance*np.arange(n_left, 0, -1),
        grid,
        grid[-1] + point_distance*np.arange(1, n_right + 1),
    ])

    result_vec = []
    for idx, channel in enumerate(('intensity_uu', 'intensity_dd'), 1):
        signal = np.concatenate([np.zeros(sampling), iets[channel]])[order]
        result = np.convolve(signal, gauss, mode='full')
        result_vec.append(result)
        # One data file per spin channel: <prefix>1 (uu) and <prefix>2 (dd).
        np.savetxt(outfilename + str(idx),
                   np.column_stack((p_wavenum, result)),
                   fmt=['%.18e', '%.18e'])

    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.set_ylabel(r'$\Delta W_{\alpha}$ / $W_\alpha^{tot}$')
    ax1.set_xlabel('Wavenumber [1/cm]')
    ax2 = ax1.twiny()
    ax2.set_xlabel('Bias Voltage [mV]')

    # h*c*(100/m)/e gives volts per 1/cm; the extra 1e3 converts to mV.
    wavenum_1cm2mV = 1e5*const.c*const.h/const.e
    ax1.plot(p_wavenum, abs(result_vec[1] - result_vec[0])/g0)
    # Invisible line whose only purpose is to set the limits of the
    # secondary (bias voltage) axis.
    ax2.plot([p_wavenum.min()*wavenum_1cm2mV,
              p_wavenum.max()*wavenum_1cm2mV],
             2*[0.0], ' ')
    ax1.yaxis.grid()
    fig.savefig(savefig, bbox_inches='tight')


if __name__ == '__main__':
    main(sys.argv)
