# Source code for scitex_dsp._spectral._modulation_index

#!/usr/bin/env python3
# Time-stamp: "2024-11-04 02:09:55 (ywatanabe)"
# File: ./scitex_repo/src/scitex/dsp/_modulation_index.py

try:
    import torch

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    torch = None

from scitex_decorators import signal_fn

if TORCH_AVAILABLE:
    from scitex_nn._ModulationIndex import ModulationIndex


@signal_fn
def modulation_index(pha, amp, n_bins=18, amp_prob=False):
    """Compute the modulation index (MI) from phase and amplitude signals.

    Thin functional wrapper: builds a ``ModulationIndex`` module configured
    with ``n_bins`` and ``amp_prob`` and applies it to ``(pha, amp)``.

    Parameters
    ----------
    pha
        Phase signal, shaped
        (batch_size, n_chs, n_freqs_pha, n_segments, seq_len).
    amp
        Amplitude signal, shaped
        (batch_size, n_chs, n_freqs_amp, n_segments, seq_len).
    n_bins : int, default 18
        Forwarded unchanged to ``ModulationIndex``.
    amp_prob : bool, default False
        Forwarded unchanged to ``ModulationIndex``.

    Returns
    -------
    Whatever the ``ModulationIndex`` instance returns for ``(pha, amp)``.

    Raises
    ------
    ImportError
        If PyTorch could not be imported when this module was loaded.

    Notes
    -----
    The function is wrapped by ``signal_fn``; that decorator is defined
    outside this file, so any input/output conversion it performs is not
    described here.
    """
    if TORCH_AVAILABLE:
        mi_module = ModulationIndex(n_bins=n_bins, amp_prob=amp_prob)
        return mi_module(pha, amp)

    # ModulationIndex is only imported when torch is present, so fail with
    # an actionable message instead of a NameError.
    raise ImportError(
        "PyTorch is not installed. Please install with: pip install torch"
    )
def _reshape(x, batch_size=2, n_chs=4):
    """Tile ``x`` across new leading batch and channel axes (demo helper).

    ``x`` is converted to a float32 tensor, two singleton axes are inserted
    in front, and the result is repeated ``batch_size`` times along the
    first axis and ``n_chs`` times along the second.

    Parameters
    ----------
    x
        Array-like accepted by ``torch.tensor``.
        NOTE(review): the five repeat factors suggest ``x`` is expected to
        be 3-D, e.g. (n_freqs, n_segments, seq_len) - confirm with caller.
    batch_size : int, default 2
        Repeat count for the new first axis.
    n_chs : int, default 4
        Repeat count for the new second axis.

    Returns
    -------
    torch.Tensor
        float32 tensor; for a 3-D ``x`` the shape is
        (batch_size, n_chs, *x.shape).

    Raises
    ------
    ImportError
        If PyTorch is not installed (module-level ``torch`` is ``None``).
    """
    if torch is None:
        # Mirror modulation_index(): without this guard a missing PyTorch
        # surfaces as an opaque AttributeError on None.
        raise ImportError(
            "PyTorch is not installed. Please install with: pip install torch"
        )

    as_float = torch.tensor(x).float()
    with_lead_axes = as_float.unsqueeze(0).unsqueeze(0)
    return with_lead_axes.repeat(batch_size, n_chs, 1, 1, 1)


if __name__ == "__main__":
    # Demo: compute PAC with tensorpac and with scitex's ModulationIndex on
    # the same signal, then plot both for comparison.
    import sys

    import matplotlib.pyplot as plt
    import scitex

    # Start
    # session.start hands back replacements for sys.stdout / sys.stderr and
    # plt, which are rebound here for the rest of the script.
    CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(
        sys, plt, fig_scale=3
    )

    # Parameters
    FS = 512
    T_SEC = 5

    # Demo signal
    xx, tt, fs = scitex.dsp.demo_sig(fs=FS, t_sec=T_SEC, sig_type="tensorpac")
    # xx.shape: (8, 19, 20, 512)

    # Tensorpac
    (
        pha,
        amp,
        freqs_pha,
        freqs_amp,
        pac_tp,
    ) = scitex.dsp.utils.pac.calc_pac_with_tensorpac(xx, fs, t_sec=T_SEC)

    # GPU calculation with scitex.dsp.nn.ModulationIndex
    # _reshape adds the leading (batch, channel) axes before the call.
    pha, amp = _reshape(pha), _reshape(amp)
    pac_scitex = scitex.dsp.modulation_index(pha, amp).cpu().numpy()
    # Every (batch, channel) entry is a tiled copy of the same input, so the
    # first one is taken for plotting.
    i_batch, i_ch = 0, 0
    pac_scitex = pac_scitex[i_batch, i_ch]

    # Plots
    fig = scitex.dsp.utils.pac.plot_PAC_scitex_vs_tensorpac(
        pac_scitex, pac_tp, freqs_pha, freqs_amp
    )
    fig.suptitle("MI (modulation index) calculation")
    scitex.io.save(fig, "modulation_index.png")

    # Close
    scitex.session.close(CONFIG)

# EOF

# Path marker kept from the original file; a bare string expression with no
# runtime effect.
""" /home/ywatanabe/proj/entrance/scitex/dsp/_modulation_index.py """

# EOF