# Source code for scitex_dsp._spectral._wavelet

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-11-04 02:12:00 (ywatanabe)"
# File: ./scitex_repo/src/scitex/dsp/_wavelet.py

"""scitex.dsp.wavelet function"""

import warnings

from scitex_decorators import batch_fn, signal_fn
from scitex_nn._Wavelet import Wavelet


# Functions
@signal_fn
@batch_fn
def wavelet(
    x,
    fs,
    freq_scale="linear",
    out_scale="linear",
    device="cuda",
    batch_size=32,
):
    """Run a wavelet transform on ``x`` and return phase, amplitude and frequencies.

    Parameters
    ----------
    x : tensor-like
        Input signal. ``x.to(device)`` is called on it, so it is presumably a
        ``torch.Tensor`` by the time this body runs (assumed to be arranged by
        the ``signal_fn`` decorator -- TODO confirm).
    fs : float
        Sampling frequency, forwarded to ``Wavelet``.
    freq_scale : str
        Frequency-axis scale, forwarded to ``Wavelet``.
    out_scale : str
        If ``"log"``, the amplitude is returned as ``log(amp + 1e-5)``;
        any other value returns the amplitude exactly as ``Wavelet`` produced it.
    device : str
        Device the ``Wavelet`` module and the input are moved to.
    batch_size : int
        Not read in this body; presumably consumed by the ``batch_fn``
        decorator -- confirm against ``scitex_decorators``.

    Returns
    -------
    tuple
        ``(pha, amp, freqs)`` as returned by the ``Wavelet`` module, with
        ``amp`` log-scaled when ``out_scale == "log"``.

    Warns
    -----
    RuntimeWarning
        If the log-scaled amplitude contains NaN values.
    """
    # ``Wavelet`` is always constructed with a linear output; any log scaling
    # requested through ``out_scale`` is applied below instead.
    model = Wavelet(fs, freq_scale=freq_scale, out_scale="linear").to(device).eval()
    pha, amp, freqs = model(x.to(device))

    if out_scale == "log":
        # The small offset keeps log() finite where the amplitude is exactly 0.
        amp = (amp + 1e-5).log()
        if amp.isnan().any():
            # Emit a filterable/capturable warning instead of printing to stdout.
            warnings.warn(
                "NaN is detected while taking the logarithm of amplitude.",
                RuntimeWarning,
            )
    return pha, amp, freqs
# NOTE(review): a former hand-rolled, commented-out batching implementation of
# ``wavelet`` was removed from here; batching now presumably happens in the
# ``@batch_fn`` decorator -- recover the old code from VCS history if needed.

if __name__ == "__main__":
    # Demo: wavelet-transform a synthetic signal and save a 3-panel figure
    # (raw trace, log amplitude, phase) to "wavelet.png".
    import sys

    import matplotlib.pyplot as plt
    import numpy as np

    import scitex  # demo-only umbrella usage

    # Start
    CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(
        sys, plt, agg=True
    )

    # Parameters
    FS = 512
    SIG_TYPE = "chirp"
    T_SEC = 4

    # Demo signal
    xx, tt, fs = scitex.dsp.demo_sig(
        batch_size=64,
        n_chs=19,
        n_segments=2,
        t_sec=T_SEC,
        fs=FS,
        sig_type=SIG_TYPE,
    )

    # These signal types carry an extra segment axis; keep only the first one.
    if SIG_TYPE in ["tensorpac", "pac"]:
        i_segment = 0
        xx = xx[:, :, i_segment, :]

    # Main
    pha, amp, freqs = wavelet(xx, fs, device="cuda")
    freqs = freqs[0, 0]

    # Plots
    i_batch, i_ch = 0, 0
    fig, axes = scitex.plt.subplots(nrows=3)

    # Trace
    axes[0].plot(tt, xx[i_batch, i_ch], label=SIG_TYPE)
    axes[0].set_ylabel("Amplitude [μV]")
    axes[0].legend(loc="upper left")
    axes[0].set_title("Signal")

    # Amplitude (log-scaled for display; the 1e-5 offset avoids log(0))
    axes[1].imshow2d(
        np.log(amp[i_batch, i_ch] + 1e-5).T,
        cbar_label="Log(amplitude [μV]) [a.u.]",
        aspect="auto",
    )
    axes[1] = scitex.plt.ax.set_ticks(axes[1], x_ticks=tt, y_ticks=freqs)
    axes[1].set_ylabel("Frequency [Hz]")
    axes[1].set_title("Amplitude")

    # Phase
    axes[2].imshow2d(
        pha[i_batch, i_ch].T,
        cbar_label="Phase [rad]",
        aspect="auto",
    )
    axes[2] = scitex.plt.ax.set_ticks(axes[2], x_ticks=tt, y_ticks=freqs)
    axes[2].set_ylabel("Frequency [Hz]")
    axes[2].set_title("Phase")

    fig.suptitle("Wavelet Transformation")
    fig.supxlabel("Time [s]")
    for ax in axes:
        scitex.plt.ax.set_n_ticks(ax)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    scitex.io.save(fig, "wavelet.png")

    # Close
    scitex.session.close(CONFIG)

# EOF