# Source code for scitex_dsp._crop

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "ywatanabe (2024-11-02 22:50:46)"
# File: ./scitex_repo/src/scitex/dsp/_crop.py

import numpy as np


def crop(sig_2d, window_length, overlap_factor=0.0, axis=-1, time=None):
    """
    Crop a signal into (possibly overlapping) fixed-length windows along one axis.

    Parameters:
    - sig_2d (numpy.ndarray): The input signal. Can be multi-dimensional.
    - window_length (int): Number of samples in each window.
    - overlap_factor (float): Fraction of a window shared by consecutive windows,
      e.g. 0.5 means 50% overlap. The hop between windows is
      ``int(window_length * (1 - overlap_factor))`` and must be at least 1.
    - axis (int): The axis along which to crop. Negative values count from the end.
    - time (array-like, optional): Time vector for the cropping axis. Its length
      must equal ``sig_2d.shape[axis]``.

    Returns:
    - cropped_windows (numpy.ndarray):
        * If `axis` is the last axis: shape ``(n_windows, *sig_2d.shape[:-1], window_length)``.
        * Otherwise: ``sig_2d``'s shape with the dimension at `axis` replaced by
          ``n_windows`` and ``window_length`` appended as the last dimension.
          The other axes keep their original order.
      If the signal is shorter than `window_length`, a single truncated window is returned.
    - cropped_times (numpy.ndarray): Only returned when `time` is given.
      The per-window slices of `time`, stacked along the first axis.

    Raises:
    - ValueError: if `axis` is out of range, if `time` has the wrong length, or if
      the hop size computed from `window_length` and `overlap_factor` is < 1.
    """
    ndim = sig_2d.ndim

    # Normalize and validate the axis
    if axis < 0:
        axis += ndim
    if axis >= ndim or axis < 0:
        raise ValueError("Invalid axis. Axis out of range for sig_2d dimensions.")

    if time is not None and sig_2d.shape[axis] != len(time):
        raise ValueError(
            "Length of time vector does not match signal's dimension along the specified axis."
        )

    step = int(window_length * (1 - overlap_factor))
    if step < 1:
        # A hop of 0 would divide by zero below; a negative hop yields meaningless windows.
        raise ValueError(
            "window_length * (1 - overlap_factor) must be at least 1 "
            f"(got window_length={window_length}, overlap_factor={overlap_factor})."
        )

    # Move the cropping axis to the end while keeping the other axes in their
    # original order (a plain swap would reorder them when ndim >= 3).
    sig_last = np.moveaxis(sig_2d, axis, -1)
    seq_len = sig_last.shape[-1]

    # Ensure at least one window, even if the signal is shorter than window_length
    n_windows = max(1, (seq_len - window_length) // step + 1)
    starts = [i * step for i in range(n_windows)]

    # Shape: (n_windows, *other_axes, window_length)
    cropped_windows = np.array(
        [sig_last[..., start : start + window_length] for start in starts]
    )

    if axis != ndim - 1:
        # Put the n_windows axis where the cropped axis used to be; window_length stays last.
        cropped_windows = np.moveaxis(cropped_windows, 0, axis)

    if time is None:
        return cropped_windows

    cropped_times = np.array([time[start : start + window_length] for start in starts])
    return cropped_windows, cropped_times
def main():
    """Demonstrate `crop` on a random multichannel signal and sanity-check one window."""
    import random

    FS = 128  # samples per second
    N_CHS = 19
    RECORD_S = 13
    WINDOW_S = 2
    FACTOR = 0.5  # overlap factor passed to crop

    # Convert durations in seconds to sample counts
    record_pts = int(RECORD_S * FS)
    window_pts = int(WINDOW_S * FS)
    # Hop between consecutive windows; mirrors the formula used inside crop.
    # (window_pts * FACTOR only coincides with this when FACTOR == 0.5.)
    step_pts = int(window_pts * (1 - FACTOR))

    # Demo signal
    sig2d = np.random.rand(N_CHS, record_pts)
    time = np.arange(record_pts) / FS

    # Main
    xx, tt = crop(sig2d, window_pts, overlap_factor=FACTOR, time=time)
    print(f"sig2d.shape: {sig2d.shape}")
    print(f"xx.shape: {xx.shape}")

    # Validation: a randomly chosen window must match the corresponding raw slice
    i_seg = random.randint(0, len(xx) - 1)
    start = i_seg * step_pts
    end = start + window_pts
    assert np.allclose(sig2d[:, start:end], xx[i_seg])
    assert np.allclose(time[start:end], tt[i_seg])


if __name__ == "__main__":
    import sys

    import matplotlib.pyplot as plt
    import scitex

    CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(
        sys, plt, verbose=False
    )
    main()
    scitex.session.close(CONFIG, verbose=False, notify=False)

# EOF