#!/usr/bin/env python

import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sma
import statsmodels.tsa.stattools as smts
from scipy import signal
from statsmodels.tsa.ar_model import AutoReg

import rapidtide.io as tide_io

# Passed through to the reader's own ``debug`` argument below.
debug = True

# Read the oversampled moving-regressor timecourse file with rapidtide's BIDS
# TSV reader.  The path is relative, so it resolves against the current
# working directory; neednotexist=False is forwarded unchanged to the reader.
(
    insamplerate,
    instarttime,
    incolumns,
    indata,
    incompressed,
    incolsource,
    inextrainfo,
) = tide_io.readbidstsv(
    "../dst/sub-RAPIDTIDETEST_desc-oversampledmovingregressor_timeseries.json",
    neednotexist=False,
    debug=debug,
)
print(f"{indata.shape=}")

# Model order: used both as the Levinson-Durbin lag count and as the number of
# AutoReg lags in the per-timecourse loop.
nlags = 15

# NOTE(review): these look like ARIMA (p, d, q) orders, but nothing in the
# visible part of this script reads them - confirm before removing.
p = 2
d = 0
q = 0

# For each row of indata: print and plot the Levinson-Durbin AR diagnostics,
# then fit an AutoReg model of order nlags and plot the row against its
# prewhitened version (both scaled to unit standard deviation).
for i in range(indata.shape[0]):
    timecourse = indata[i, :]

    # isacov=False: the input is the raw series, not its autocovariance.
    sigma_v, arcoefs, pacf, sigma, phi = smts.levinson_durbin(
        timecourse, nlags=nlags, isacov=False
    )

    print(f"{sigma_v=}")
    print(f"{arcoefs=}")
    print(f"{pacf=}")
    print(f"{sigma=}")
    print(f"{phi=}")

    # One figure per diagnostic; each plt.show() blocks until it is closed.
    for diagnostic in (arcoefs, pacf, sigma):
        plt.plot(diagnostic)
        plt.show()

    ar_model = AutoReg(timecourse, lags=nlags)
    ar_fit = ar_model.fit()
    ar_params = ar_fit.params

    # With AutoReg's default constant trend, ar_params[0] is the intercept and
    # ar_params[1:] are the lag coefficients phi_1..phi_nlags of
    #     x[t] = c + sum_k phi_k * x[t - k] + e[t].
    # Prewhitening recovers e[t] (up to the constant c), i.e. it applies the
    # FIR polynomial A(z) = 1 - sum_k phi_k z^-k.  A(z) therefore belongs in
    # the numerator with a unit denominator.  Putting it in the denominator
    # (b=[1], a=A) is the AR synthesis filter 1/A(z), which adds the fitted
    # autocorrelation instead of removing it.
    b = np.insert(-ar_params[1:], 0, 1)
    a = [1.0]

    # Apply the whitening filter.  lfilter starts from zero initial state, so
    # the first nlags output samples contain a start-up transient.
    prewhitened_signal = signal.lfilter(b, a, timecourse)
    plt.plot(timecourse / np.std(timecourse))
    plt.plot(prewhitened_signal / np.std(prewhitened_signal))
    plt.show()
