Source code for scitex_dsp.utils._zero_pad

#!/usr/bin/env python3
# Time-stamp: "2024-11-26 10:30:34 (ywatanabe)"
# File: ./scitex_repo/src/scitex/dsp/utils/_zero_pad.py

# Hard-coded absolute path of this source file on the author's machine
# (mirrors the "File:" header above); not referenced elsewhere in the visible code.
THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/utils/_zero_pad.py"

import numpy as np

try:
    import torch
    import torch.nn.functional as F

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    torch = None
    F = None


def _check_torch():
    """Ensure the optional PyTorch dependency is importable.

    Returns ``None`` when ``TORCH_AVAILABLE`` is true.

    Raises:
        ImportError: If PyTorch could not be imported at module load time.
    """
    if TORCH_AVAILABLE:
        return
    raise ImportError(
        "PyTorch is not installed. Please install with: pip install torch"
    )


def _zero_pad_1d(x, target_length):
    """Zero pad a tensor along its last axis to ``target_length``.

    Zeros are split between both ends; when the amount of padding is odd,
    the extra zero goes on the right.

    Args:
        x: Tensor, or anything accepted by ``torch.tensor``. Intended for 1D
            input; for higher-rank input the last axis is the one padded.
        target_length: Desired size of the last axis.

    Returns:
        torch.Tensor: Tensor whose last axis has ``target_length`` elements.

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If ``target_length`` is shorter than the current length.
    """
    _check_torch()
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    # Measure the same axis that F.pad pads with a 2-element pad spec (the
    # last one); len(x) would measure the first axis instead.
    current_length = x.shape[-1]
    padding_needed = target_length - current_length
    if padding_needed < 0:
        # Negative padding makes F.pad trim the tensor, which would silently
        # drop data in a function that is supposed to only add zeros.
        raise ValueError(
            f"target_length ({target_length}) is shorter than the input "
            f"length ({current_length})"
        )
    padding_left = padding_needed // 2
    padding_right = padding_needed - padding_left
    return F.pad(x, (padding_left, padding_right), "constant", 0)


def zero_pad(xs, dim=0):
    """Zero pad a list of arrays to the same length and stack them.

    Each input is padded with zeros along its last axis (split between both
    ends by ``_zero_pad_1d``) up to the length of the longest input. The
    padded tensors are then combined with ``torch.stack``.

    Args:
        xs: Iterable of tensors, NumPy arrays, or other inputs accepted by
            ``torch.tensor``. After padding, all items must have the same
            shape, as required by ``torch.stack``.
        dim: Dimension to stack along (forwarded to ``torch.stack``).

    Returns:
        torch.Tensor: Stacked tensor with zero padding.

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If ``xs`` is empty.
    """
    _check_torch()

    # Existing tensors are used as-is; NumPy arrays and other array-likes are
    # converted with torch.tensor (the original ndarray/other branches were
    # identical, so one expression covers both).
    tensors = [x if isinstance(x, torch.Tensor) else torch.tensor(x) for x in xs]
    if not tensors:
        raise ValueError("zero_pad() requires at least one array to pad")

    # Length is measured along the last axis, the axis that a 2-element
    # F.pad spec pads; for 1D inputs this equals len(x).
    max_len = max(t.shape[-1] for t in tensors)
    return torch.stack([_zero_pad_1d(t, max_len) for t in tensors], dim=dim)
# EOF