# Source code for dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_ron.utils.checks

from __future__ import annotations

import torch


def check_D(D: torch.Tensor, *, require_contiguous: bool = False) -> None:
    """Validate a batched pairwise-distance tensor before launching a kernel.

    Args:
        D: Tensor expected to have shape ``(B, N, M)`` (three dimensions) and
            a floating dtype among float32, float64, float16 or bfloat16.
        require_contiguous: When True, reject non-contiguous tensors instead
            of silently accepting them. Defaults to False, because the
            launcher makes ``D`` contiguous itself.

    Raises:
        TypeError: If ``D`` is not a ``torch.Tensor`` or has an unsupported
            dtype.
        ValueError: If ``D`` is not 3-dimensional, or if
            ``require_contiguous`` is True and ``D`` is not contiguous.
    """
    if not isinstance(D, torch.Tensor):
        raise TypeError("D must be a torch.Tensor")
    if D.dim() != 3:
        raise ValueError(f"D must have shape (B,N,M). Got {tuple(D.shape)}")
    if D.dtype not in (torch.float32, torch.float64, torch.float16, torch.bfloat16):
        raise TypeError(f"Unsupported dtype {D.dtype}")
    # The launcher calls .contiguous() itself, so a non-contiguous D is
    # accepted by default; callers that want strict validation opt in.
    if require_contiguous and not D.is_contiguous():
        raise ValueError(
            "D must be contiguous when require_contiguous=True; "
            "call D.contiguous() first"
        )