Source code for scitex_dsp.utils._ensure_3d
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-11-05 01:04:03 (ywatanabe)"
# File: ./scitex_repo/src/scitex/dsp/utils/_ensure_3d.py
from scitex_decorators import torch_fn
[docs]
@torch_fn
def ensure_3d(x):
    """Expand a 1-D or 2-D input to three dimensions with singleton axes.

    * 1-D input, assumed ``(seq_len,)``, becomes ``(1, 1, seq_len)``.
    * 2-D input, assumed ``(batch_size, seq_len)``, becomes
      ``(batch_size, 1, seq_len)``.
    * Input of any other rank is returned unchanged (no error is raised).

    Parameters
    ----------
    x
        Object exposing ``.ndim`` and ``.unsqueeze`` (e.g. ``torch.Tensor``).
        NOTE(review): the ``torch_fn`` decorator applied to this function
        presumably converts other array types to tensors first -- confirm.

    Returns
    -------
    The input with singleton axes inserted as described above.
    """
    if x.ndim == 1:  # assumes (seq_len,)
        # Add both the leading batch axis and the middle axis.
        x = x.unsqueeze(0).unsqueeze(0)
    elif x.ndim == 2:  # assumes (batch_size, seq_len)
        # Only the middle axis is missing.
        x = x.unsqueeze(1)
    return x
# EOF