#!/usr/bin/env python

import matplotlib.pyplot as plt
import numpy as np

import rapidtide.correlate as tide_corr
import rapidtide.filter as tide_filt
import rapidtide.fit as tide_fit


def makeashiftedvector(Fs, arrlen, overallamp, noiseamp, timeshift):
    """Synthesize a noisy three-tone test signal evaluated at shifted times.

    The clean signal is the sum of three sinusoids at 0.15, 0.11 and 0.05
    cycles per unit of the returned time axis, with relative amplitudes
    1.0, 0.5 and 0.25 and phases 0, pi/2 and pi.  It is scaled by
    ``overallamp``, and uniform noise from ``np.random.rand`` (values in
    [0, 1), so not zero-mean) scaled by ``noiseamp`` is added on top.

    Parameters
    ----------
    Fs : float
        Sample rate; consecutive time points are ``1 / Fs`` apart.
    arrlen : int
        Number of samples to generate.
    overallamp : float
        Scale factor applied to the summed sinusoids (not to the noise).
    noiseamp : float
        Scale factor applied to the uniform random noise.
    timeshift : float
        Offset added to every time point before the sinusoids are
        evaluated, in the same units as the returned time axis.

    Returns
    -------
    thevector : ndarray
        The noisy, shifted signal, of length ``arrlen``.
    xaxis : ndarray
        The unshifted time axis, starting at 0 with spacing ``1 / Fs``.
    """
    xaxis = (arrlen / Fs) * np.linspace(0.0, 1.0, arrlen, endpoint=False)
    shiftedtime = xaxis + timeshift

    def tone(amplitude, angularfreq, phase):
        # One sinusoidal component sampled on the shifted time axis.
        return amplitude * np.sin(angularfreq * shiftedtime + phase)

    cleansignal = (
        tone(1.0, 0.15 * 2.0 * np.pi, 0.0)
        + tone(0.5, 0.11 * 2.0 * np.pi, np.pi / 2.0)
        + tone(0.25, 0.05 * 2.0 * np.pi, np.pi)
    )
    # The noise is drawn even when noiseamp is 0, so the global RNG state
    # advances by arrlen draws on every call.
    noise = noiseamp * np.random.rand(arrlen)
    return overallamp * cleansignal + noise, xaxis


# Build two synthetic timecourses sampled at thefreq and spanning 400 time
# units: a reference, and a copy evaluated 10 time units later.
thefreq = 30.0
numpoints = int(thefreq * 400)
fixedvec, xaxis = makeashiftedvector(thefreq, numpoints, 0.5, 0.01, 0.0)
# The time axis depends only on the rate and length, so the second copy
# returned here is redundant and discarded.
movingvec, _ = makeashiftedvector(thefreq, numpoints, 0.5, 0.01, 10.0)

# NOTE(review): presumably aligns movingvec to fixedvec and reports the delay
# found - confirm against the rapidtide.correlate documentation.
alignresults = tide_corr.aligntcwithref(
    fixedvec, movingvec, thefreq, display=True, verbose=True
)
alignedvec, maxdelay, maxval, failreason = alignresults
print(f"{maxdelay=}, {maxval=}, {failreason=}")

# NOTE(review): presumably fits alignedvec to fixedvec and removes the fitted
# component - confirm against the rapidtide.fit documentation.  The fifth
# return value is not used by this script.
fitresults = tide_fit.linfitfilt(fixedvec, alignedvec, debug=True)
filteredvec, datatoremove, R, coffs, _ = fitresults

# Stack the traces vertically, one unit apart, so they can be compared by eye.
traces = (fixedvec, movingvec, alignedvec, datatoremove, filteredvec)
for position, trace in enumerate(traces):
    plt.plot(xaxis, trace + float(position))
print(R)
plt.show()
