# Source code for scitex_dsp._spectral._pac

#!/usr/bin/env python3
# Time-stamp: "2024-11-26 22:24:40 (ywatanabe)"
# File: ./scitex_repo/src/scitex/dsp/_pac.py

# Absolute path of this source file, matching the `# File:` header above.
# It is not referenced by any code in the visible part of this module.
# NOTE(review): the module now appears to live at `scitex_dsp._spectral._pac`;
# confirm whether this recorded path is still meant to be kept up to date.
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/_pac.py"

import sys

import matplotlib.pyplot as plt
import numpy as np

try:
    import torch

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    torch = None


from scitex_decorators import signal_fn

if TORCH_AVAILABLE:
    from scitex_nn._PAC import PAC


def _check_torch():
    """Fail fast with a helpful message when PyTorch is unavailable.

    ``TORCH_AVAILABLE`` is set once at import time by the guarded
    ``import torch``. Torch-dependent entry points call this helper first.

    Raises:
        ImportError: if the earlier ``import torch`` failed.
    """
    if TORCH_AVAILABLE:
        return
    message = "PyTorch is not installed. Please install with: pip install torch"
    raise ImportError(message)


# NOTE: this string comes after executable statements, so Python does not
# treat it as the module docstring. It is a bare expression that is evaluated
# and discarded at import time, and serves only as a section label.
"""
scitex.dsp.pac function
"""


# @batch_fn
@signal_fn
def pac(
    x,
    fs,
    pha_start_hz=2,
    pha_end_hz=20,
    pha_n_bands=100,
    amp_start_hz=60,
    amp_end_hz=160,
    amp_n_bands=100,
    device="cuda",
    batch_size=1,
    batch_size_ch=-1,
    fp16=False,
    trainable=False,
    n_perm=None,
    amp_prob=False,
):
    """
    Compute the phase-amplitude coupling (PAC) for signals.

    The function accepts PyTorch tensors, NumPy arrays or pandas DataFrames;
    input conversion is delegated to the ``signal_fn`` decorator.

    Arguments:
        - x (torch.Tensor | np.ndarray | pd.DataFrame): Input signal.
          Expected shape: (batch_size, n_chs, seq_len). The last axis is
          passed to ``PAC`` as the sequence length, and channel batching
          slices axis 1. NOTE(review): the original docstring mentioned an
          alternative accepted shape but was truncated; confirm what other
          layouts ``signal_fn`` accepts.
        - fs (float): Sampling frequency of the input signal.
        - pha_start_hz (float, optional): Start frequency for phase bands.
          Default is 2 Hz.
        - pha_end_hz (float, optional): End frequency for phase bands.
          Default is 20 Hz.
        - pha_n_bands (int, optional): Number of phase bands. Default is 100.
        - amp_start_hz (float, optional): Start frequency for amplitude
          bands. Default is 60 Hz.
        - amp_end_hz (float, optional): End frequency for amplitude bands.
          Default is 160 Hz.
        - amp_n_bands (int, optional): Number of amplitude bands.
          Default is 100.
        - device (str, optional): Device that the PAC model and the input
          are moved to with ``.to(device)``. Default is "cuda".
        - batch_size (int, optional): Not used by this function body. The
          ``@batch_fn`` decorator that would consume it is commented out.
        - batch_size_ch (int, optional): Number of channels per forward
          pass. -1 (the default) processes all channels at once. Any other
          value must be a positive integer.
        - fp16, trainable, n_perm, amp_prob: Forwarded unchanged to ``PAC``.

    Returns:
        - PAC values with shape (batch_size, n_chs, pha_n_bands, amp_n_bands).
          With ``batch_size_ch == -1`` this is the raw model output on
          ``device``, still attached to the autograd graph. Otherwise each
          channel chunk is detached, moved to CPU and concatenated along the
          channel axis. ``signal_fn`` may convert the result back to the
          input's type; TODO confirm.
        - The model's ``PHA_MIDS_HZ``: mid frequencies of the phase bands.
        - The model's ``AMP_MIDS_HZ``: mid frequencies of the amplitude bands.

    Raises:
        - ImportError: if PyTorch is not installed.
        - ValueError: if ``batch_size_ch`` is neither -1 nor positive.

    Example:
        FS = 512
        T_SEC = 4
        xx, tt, fs = scitex.dsp.demo_sig(
            batch_size=1, n_chs=1, fs=FS, t_sec=T_SEC, sig_type="tensorpac"
        )
        pac, pha_mids_hz, amp_mids_hz = scitex.dsp.pac(xx, fs)
    """
    _check_torch()

    # Previously 0 caused a ZeroDivisionError and other negatives an obscure
    # torch.cat failure on an empty list; reject them with a clear message.
    if batch_size_ch != -1 and batch_size_ch <= 0:
        raise ValueError(
            f"batch_size_ch must be -1 (no channel batching) or a positive "
            f"integer, got {batch_size_ch!r}"
        )

    def _run_in_channel_chunks(model):
        """Run ``model`` over ``batch_size_ch``-channel slices of ``x``.

        Each chunk is moved to ``device`` only for its forward pass, then
        detached and brought back to CPU to bound device memory. The chunks
        are concatenated along the channel axis.
        """
        n_chs = x.shape[1]
        chunks = [
            # Slicing past n_chs is clamped, so the last chunk may be smaller.
            model(x[:, start : start + batch_size_ch, :].to(device))
            .detach()
            .cpu()
            for start in range(0, n_chs, batch_size_ch)
        ]
        return torch.cat(chunks, dim=1)

    m = PAC(
        x.shape[-1],
        fs,
        pha_start_hz=pha_start_hz,
        pha_end_hz=pha_end_hz,
        pha_n_bands=pha_n_bands,
        amp_start_hz=amp_start_hz,
        amp_end_hz=amp_end_hz,
        amp_n_bands=amp_n_bands,
        fp16=fp16,
        trainable=trainable,
        n_perm=n_perm,
        amp_prob=amp_prob,
    ).to(device)

    if batch_size_ch == -1:
        return m(x.to(device)), m.PHA_MIDS_HZ, m.AMP_MIDS_HZ

    return _run_in_channel_chunks(m), m.PHA_MIDS_HZ, m.AMP_MIDS_HZ
if __name__ == "__main__":
    # Smoke-test demo: compute channel-batched PAC on random noise.
    # Usage: python -m scitex.dsp._pac
    import scitex

    # Start the scitex session; it rebinds sys.stdout, sys.stderr and plt.
    CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)

    # Fall back to CPU so the demo does not crash on machines without CUDA.
    device = "cuda" if TORCH_AVAILABLE and torch.cuda.is_available() else "cpu"

    # Named `pac_values` so the module-level `pac` function is not shadowed.
    pac_values, freqs_pha, freqs_amp = scitex.dsp.pac(
        np.random.rand(1, 16, 24000),
        400,
        device=device,
        batch_size=1,
        batch_size_ch=8,
        # NOTE(review): half precision is only requested on GPU here; confirm
        # whether PAC supports fp16 on CPU before enabling it there.
        fp16=(device == "cuda"),
        n_perm=16,
    )

    # A commented-out tensorpac comparison and method-level test harness used
    # to live here. It referenced helpers not imported in this module
    # (e.g. `batch_torch_fn`, `printc`) and was removed as dead code.

# EOF