#!/usr/bin/env python

import matplotlib.pyplot as plt
import numpy as np

import rapidtide.correlate as tide_corr
import rapidtide.filter as tide_filt
import rapidtide.fit as tide_fit

# ---------------------------------------------------------------------------
# Synthetic test data: three sinusoidal regressors, a known linear mix of them
# plus a constant, and additive noise.  The known amplitudes and intercept are
# what the fits further down are expected to recover.
# ---------------------------------------------------------------------------

# Number of samples, and a phase axis covering [0, 2*pi) in arrlen equal steps.
arrlen = 200
xaxis = 2.0 * np.pi * np.linspace(0.0, 1.0, arrlen, endpoint=False)

# Sample rate handed to tide_corr.aligntcwithref later in the script.
Fs = 1.0

# Each regressor is sin(freq * xaxis + phase), so freq is the number of cycles
# across the whole array (freq2 is deliberately not a whole number of cycles).
freq1 = 1.0
freq2 = 3.4
freq3 = 7.0

# Ground-truth mixing weights and constant term.
amp1 = 1.0
amp2 = 0.5
amp3 = 0.25
intercept = 0.1

phase1 = 0.0
phase2 = np.pi / 2.0
phase3 = np.pi

# Peak-to-peak width of the uniform noise added to the clean signal.
noiseamp = 0.5


# One regressor per column.
evs = np.zeros((arrlen, 3), dtype=float)

evs[:, 0] = np.sin(freq1 * xaxis + phase1)
evs[:, 1] = np.sin(freq2 * xaxis + phase2)
evs[:, 2] = np.sin(freq3 * xaxis + phase3)

# Noise-free signal with known coefficients.
datavec = amp1 * evs[:, 0] + amp2 * evs[:, 1] + amp3 * evs[:, 2] + intercept

# np.random.rand draws from [0, 1), so the raw samples have a mean of 0.5.  Centre them on
# zero: otherwise the noise adds a constant noiseamp / 2 to the signal, which biases any
# fitted constant term away from `intercept` and makes the intercept comparison meaningless.
noisevec = noiseamp * (np.random.rand(arrlen) - 0.5)
invec = datavec + noisevec

# Experiment 1: fit the noisy signal against all three regressors at once and compare the
# coefficients and intercept that come back with the values used to build the signal.
fit_results = tide_fit.linfitfilt(invec, evs, debug=True)
filtered, datatoremove, R, outcoffs, outintercept = fit_results

# Differences between fitted and ground-truth values.
incoffs = [amp1, amp2, amp3]
coffdiffs = outcoffs - incoffs
interceptdiffs = outintercept - intercept
print(f"{incoffs=}, {outcoffs=}, {intercept=}, {outintercept=}, {coffdiffs=}, {interceptdiffs=}")

# Traces to display, bottom to top: the three regressors, the noisy input, the fitted
# component, the filtered output, and the noise-free signal minus the fitted component.
stacked_traces = (
    evs[:, 0],
    evs[:, 1],
    evs[:, 2],
    invec,
    datatoremove,
    filtered,
    datavec - datatoremove,
)

# The first trace sits on the baseline; each later one is raised by one more unit so the
# curves do not overlap.
offset = 0.0
plt.plot(stacked_traces[0])
for trace in stacked_traces[1:]:
    offset += 1.0
    plt.plot(trace + offset)
print(R)
plt.show()

# Experiment 2: repeat the fit on the same noisy input, but with only the first regressor
# (the 1-D column evs[:, 0]) as the model.  The fifth return value is discarded as `dummy`.
# Nothing below inspects these results unless the disabled plot is re-enabled.
filtered, datatoremove, R, coffs, dummy = tide_fit.linfitfilt(invec, evs[:, 0], debug=True)

# The string literal below is a disabled plotting section kept for reference.  As a bare
# expression statement it has no effect when the script runs.
"""
offset = 0.0
plt.plot(evs[:, 0])
offset += 1.0
plt.plot(invec + offset)
offset += 1.0
plt.plot(datatoremove + offset)
offset += 1.0
plt.plot(filtered + offset)
print(R)
plt.show()
"""

# Experiment 3: build a signal that depends nonlinearly on the first regressor (linear,
# squared and cubed terms with weights amp1, amp1 / 2 and amp1 / 4) plus uniform noise,
# then fit it with the plain three-regressor model.
fundamental = evs[:, 0]
linear_term = amp1 * fundamental
quadratic_term = amp1 * 0.5 * fundamental * fundamental
cubic_term = amp1 * 0.25 * fundamental * fundamental * fundamental
uniform_noise = noiseamp * np.random.rand(arrlen)
invec2 = linear_term + quadratic_term + cubic_term + uniform_noise

filtered, datatoremove, R, coffs, dummy = tide_fit.linfitfilt(
    invec2,
    evs,
    debug=True,
)

# Disabled plotting section, kept for reference; the bare string has no runtime effect.
"""
offset = 0.0
plt.plot(evs[:, 0])
offset += 1.0
plt.plot(invec2 + offset)
offset += 1.0
plt.plot(datatoremove + offset)
offset += 1.0
plt.plot(filtered + offset)
print(R)
plt.show()
"""

# Experiment 4: fit the nonlinear signal invec2 with expandedlinfitfilt and ncomps=3.
# Compared with linfitfilt, this call returns an extra value, `thenewevs`, in second
# position; it is reused as the regressor set for the following fit.
# NOTE(review): the disabled plot below shows evs[:, 0], its square and its cube, which
# suggests the routine expands regressors into powers - confirm against
# rapidtide.fit.expandedlinfitfilt.
# NOTE(review): the full 3-column `evs` is passed here rather than evs[:, 0] - confirm
# that this is intended and supported.
filtered, thenewevs, datatoremove, R, coffs = tide_fit.expandedlinfitfilt(
    invec2, evs, ncomps=3, debug=True
)

# Disabled plotting section, kept for reference; the bare string has no runtime effect.
"""
offset = 0.0
plt.plot(evs[:, 0])
offset += 1.0
plt.plot(evs[:, 0] * evs[:, 0])
offset += 1.0
plt.plot(evs[:, 0] * evs[:, 0] * evs[:, 0])
offset += 1.0
plt.plot(invec2 + offset)
offset += 1.0
plt.plot(datatoremove + offset)
offset += 1.0
plt.plot(filtered + offset)
print(R)
plt.show()
"""

# Experiment 5: fit invec2 once more with the ordinary linfitfilt, this time using the
# regressor set `thenewevs` returned by the expandedlinfitfilt call above.
# NOTE(review): presumably a cross-check that both routines give the same result for the
# same regressors - confirm the intent.
filtered, datatoremove, R, coffs, dummy = tide_fit.linfitfilt(invec2, thenewevs, debug=True)
# Disabled plotting section, kept for reference; the bare string has no runtime effect.
"""
offset = 0.0
plt.plot(invec2 + offset)
offset += 1.0
plt.plot(datatoremove + offset)
offset += 1.0
plt.plot(filtered + offset)
print(R)
plt.show()
"""

# Experiment 6: the signal is the first regressor advanced in phase by pi / 10, plus uniform
# noise.  The regressor is first aligned to the signal with aligntcwithref, and the aligned
# regressor is then used as the single model term for the fit.
invec3 = amp1 * np.sin(freq1 * xaxis + phase1 + np.pi / 10.0) + noiseamp * np.random.rand(arrlen)

# invec3 is passed as the first timecourse and the unshifted regressor as the second, with
# Fs as the sample rate.  display/verbose turn on the routine's own diagnostics.
alignedev, maxdelay, maxval, failreason = tide_corr.aligntcwithref(
    invec3,
    evs[:, 0],
    Fs,
    display=True,
    verbose=True,
)
print(f"{maxdelay=}, {maxval=}, {failreason=}")
filtered, datatoremove, R, coffs, dummy = tide_fit.linfitfilt(invec3, alignedev, debug=True)

# Stacked plot, bottom to top: original regressor, aligned regressor, noisy input, fitted
# component, filtered output.  Each trace is raised by the running offset so that no two
# curves overlap.
offset = 0.0
plt.plot(evs[:, 0])
offset += 1.0
# The aligned regressor must carry the offset like every other trace; without it the curve
# is drawn on top of the original regressor and the row reserved for it stays empty.
plt.plot(alignedev + offset)
offset += 1.0
plt.plot(invec3 + offset)
offset += 1.0
plt.plot(datatoremove + offset)
offset += 1.0
plt.plot(filtered + offset)
print(R)
plt.show()
