Source code for scitex_dsp.utils._ensure_3d

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-11-05 01:04:03 (ywatanabe)"
# File: ./scitex_repo/src/scitex/dsp/utils/_ensure_3d.py

from scitex_decorators import torch_fn


@torch_fn
def ensure_3d(x):
    """Promote a 1-D or 2-D input to 3-D by inserting singleton axes.

    Parameters
    ----------
    x
        Object exposing ``ndim`` and ``unsqueeze`` (a torch tensor by the
        time the body runs).
        NOTE(review): ``torch_fn`` presumably converts array-like inputs to
        tensors before the call and may convert the result back -- confirm
        against ``scitex_decorators``.

    Returns
    -------
    The input with axes inserted as follows:

    * ``ndim == 1``: two leading axes are added, e.g. ``(seq_len,)`` becomes
      ``(1, 1, seq_len)``.
    * ``ndim == 2``: one axis is added at position 1, e.g.
      ``(batch_size, seq_len)`` becomes ``(batch_size, 1, seq_len)``.
    * any other ``ndim``: returned unchanged (no validation is performed).
    """
    if x.ndim == 1:
        # assumes (seq_len,) -> (1, 1, seq_len)
        x = x.unsqueeze(0).unsqueeze(0)
    elif x.ndim == 2:
        # assumes (batch_size, seq_len) -> (batch_size, 1, seq_len)
        x = x.unsqueeze(1)
    return x
# EOF