# Source code for dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_ron.cuda.launcher

from __future__ import annotations

import math
import numpy as np
import torch
from numba import cuda
from numba import jit, prange

from .kernels import softdtw_forward_kernel, softdtw_forward_diag_cuda
from .kernels import softdtw_backward_log_cuda, softdtw_backward_log_diag_cuda
from .kernels import softdtw_forward_diag_sqeuclid_cuda
from .kernels import softdtw_backward_log_diag_sqeuclid_cuda

# GLOBALS
TPB_LONG = 256


# HELPERS
def _diag_bounds(p: int, N: int, M: int) -> tuple[int, int]:
    i_min = max(0, p - (M - 1))
    i_max = min(N - 1, p)
    return i_min, i_max



def _threads_and_passes(N: int, M: int) -> tuple[int, int]:
    tpb = max(N, M)
    n_passes = 2 * tpb - 1
    return tpb, n_passes


# MAIN - on-the-fly D
def softdtw_forward_cuda_fused_sqeuclid(X: torch.Tensor, Y: torch.Tensor, gamma: float, bandwidth: float):
    """
    Fused SoftDTW forward for squared-euclidean distance that does NOT
    materialize the pairwise distance matrix D (B,N,M).

    Args:
        X: (B,N,D) CUDA tensor.
        Y: (B,M,D) CUDA tensor, same batch size and feature size as X.
        gamma: smoothing parameter, must be > 0.
        bandwidth: forwarded unchanged (as a float) to the kernel; can be
            -1.0 to disable.

    Returns:
        (out, R) where out is (B,) and equals R[:, -2, -2], and R is the
        padded (B,N+2,M+2) DP table filled in by the kernel.

    Raises:
        ValueError: if gamma <= 0, if X or Y is not a CUDA tensor, if either
            is not 3-dimensional, or if batch/feature sizes disagree.
    """
    # Same gamma contract as softdtw_forward_cuda.
    if gamma <= 0:
        raise ValueError(f"gamma must be > 0, got {gamma}")
    if not (X.is_cuda and Y.is_cuda):
        raise ValueError("Expected CUDA tensors X and Y")
    if X.dim() != 3 or Y.dim() != 3:
        raise ValueError(f"Expected X,Y as (B,N,D)/(B,M,D). Got {tuple(X.shape)} and {tuple(Y.shape)}")
    if X.shape[0] != Y.shape[0] or X.shape[2] != Y.shape[2]:
        raise ValueError(f"Batch/features mismatch: {tuple(X.shape)} vs {tuple(Y.shape)}")

    # Detach before passing to numba
    X_ = X.detach().contiguous()
    Y_ = Y.detach().contiguous()

    B, N, D = X_.shape
    M = Y_.shape[1]

    # Allocate DP table: +inf everywhere except the origin of the padded grid.
    R = torch.full((B, N + 2, M + 2), math.inf, device=X_.device, dtype=X_.dtype)
    R[:, 0, 0] = 0.0

    X_ca = cuda.as_cuda_array(X_)
    Y_ca = cuda.as_cuda_array(Y_)
    R_ca = cuda.as_cuda_array(R)

    # Scalar conversions are loop-invariant, so do them once.
    gamma_f = float(gamma)
    bw = float(bandwidth)  # can be -1.0 to disable

    # Anti-diagonals over unpadded (i,j): p = i + j, i in [0,N-1], j in [0,M-1].
    # One launch per diagonal: launch p + 1 starts only after launch p was issued.
    for p in range(N + M - 1):
        i_min, i_max = _diag_bounds(p, N, M)
        if i_max < i_min:
            continue

        diag_len = i_max - i_min + 1
        grid_x = (diag_len + TPB_LONG - 1) // TPB_LONG  # ceil(diag_len / TPB_LONG)

        # grid=(grid_x, B), so batch = blockIdx.y in kernel
        softdtw_forward_diag_sqeuclid_cuda[(grid_x, B), TPB_LONG](
            X_ca, Y_ca, R_ca,
            gamma_f, bw,
            N, M, D, p,
        )

    out = R[:, -2, -2].contiguous()
    return out, R
def softdtw_backward_cuda_fused_sqeuclid(X: torch.Tensor, Y: torch.Tensor, R: torch.Tensor, gamma: float, bandwidth: float):
    """
    Fused SoftDTW backward (log-space) for squared-euclidean distance that
    does NOT materialize D_pad.

    Args:
        X: (B,N,D) CUDA tensor.
        Y: (B,M,D) CUDA tensor.
        R: (B,N+2,M+2) CUDA tensor (from forward). It is not modified; the
            boundary conditions are applied to a private copy.
        gamma: smoothing parameter, must be > 0.
        bandwidth: forwarded unchanged (as a float) to the kernel.

    Returns:
        E: (B,N,M) CUDA tensor (E = d SoftDTW / d D in linear space, via
        exp(logE)).

    Raises:
        ValueError: if gamma <= 0, if X, Y or R is not a CUDA tensor, if X/Y
            are not 3-dimensional, if batch/feature sizes disagree, or if R
            does not have shape (B,N+2,M+2).
    """
    # Validate gamma up front: 1.0 / gamma below would otherwise raise
    # ZeroDivisionError for 0 and silently accept negative values.
    if gamma <= 0:
        raise ValueError(f"gamma must be > 0, got {gamma}")
    if not (X.is_cuda and Y.is_cuda and R.is_cuda):
        raise ValueError("Expected CUDA tensors X, Y, R")
    if X.dim() != 3 or Y.dim() != 3:
        raise ValueError(f"Expected X,Y as (B,N,D)/(B,M,D). Got {tuple(X.shape)} and {tuple(Y.shape)}")
    if X.shape[0] != Y.shape[0] or X.shape[2] != Y.shape[2]:
        raise ValueError(f"Batch/features mismatch: {tuple(X.shape)} vs {tuple(Y.shape)}")

    # Detach before passing to numba
    X_ = X.detach().contiguous()
    Y_ = Y.detach().contiguous()

    B, N, D = X_.shape
    M = Y_.shape[1]

    if R.shape != (B, N + 2, M + 2):
        raise ValueError(f"Expected R shape {(B, N+2, M+2)}, got {tuple(R.shape)}")

    # ---------- boundary conditions for R ----------
    # Work on a detached copy so the caller's R is left untouched.
    R_work = R.detach().contiguous().clone()
    R_work[:, :, -1] = -math.inf
    R_work[:, -1, :] = -math.inf
    R_work[:, -1, -1] = R_work[:, -2, -2]

    # ---------- init logE ----------
    logE = torch.full((B, N + 2, M + 2), -math.inf, device=X_.device, dtype=X_.dtype)
    logE[:, -1, -1] = 0.0  # log(1)

    X_ca = cuda.as_cuda_array(X_)
    Y_ca = cuda.as_cuda_array(Y_)
    Rw_ca = cuda.as_cuda_array(R_work)
    logE_ca = cuda.as_cuda_array(logE)

    inv_gamma = float(1.0 / gamma)
    bw = float(bandwidth)

    # Reverse anti-diagonals over unpadded indices p = i + j, starting from
    # the last diagonal (N-1)+(M-1) = N+M-2 down to 0.
    for p in range(N + M - 2, -1, -1):
        i_min, i_max = _diag_bounds(p, N, M)
        if i_max < i_min:
            continue

        diag_len = i_max - i_min + 1
        grid_x = (diag_len + TPB_LONG - 1) // TPB_LONG  # ceil(diag_len / TPB_LONG)

        # grid=(grid_x, B): second grid dimension indexes the batch.
        softdtw_backward_log_diag_sqeuclid_cuda[(grid_x, B), TPB_LONG](
            X_ca, Y_ca, Rw_ca, logE_ca,
            inv_gamma, bw,
            N, M, D, p,
        )

    # crop the padding + leave log-space
    E = torch.exp(logE[:, 1:N + 1, 1:M + 1]).contiguous()
    return E
# MAIN - Full D Matrix
def softdtw_forward_cuda(D: torch.Tensor, gamma: float, bandwidth: float):
    """
    SoftDTW forward over a precomputed distance matrix D.

    Two launch strategies are used:
      * fast path - one block per batch element with max(N, M) threads,
        taken when that thread count fits in a single block;
      * long-sequence path - one kernel launch per anti-diagonal, tiled in
        blocks of TPB_LONG threads.

    Args:
        D: (B,N,M) CUDA tensor of pairwise distances.
        gamma: smoothing parameter, must be > 0.
        bandwidth: forwarded unchanged (as a float) to the kernels.

    Returns:
        (out, R) where out is (B,) and equals R[:, -2, -2], and R is the
        padded (B,N+2,M+2) DP table.

    Raises:
        ValueError: if gamma <= 0, if D is not a CUDA tensor, or if D is not
            3-dimensional.
    """
    if gamma <= 0:
        raise ValueError(f"gamma must be > 0, got {gamma}")
    if not D.is_cuda:
        raise ValueError("Expected CUDA tensor D")
    if D.dim() != 3:
        raise ValueError(f"Expected D as (B,N,M). Got {tuple(D.shape)}")

    # Detach before passing to numba
    D_ = D.detach().contiguous()
    B, N, M = D_.shape

    # Allocate DP table: +inf everywhere except the origin of the padded grid.
    R = torch.full((B, N + 2, M + 2), math.inf, device=D_.device, dtype=D_.dtype)
    R[:, 0, 0] = 0.0

    D_ca = cuda.as_cuda_array(D_)
    R_ca = cuda.as_cuda_array(R)
    gamma_f = float(gamma)
    bw = float(bandwidth)

    tpb, n_passes = _threads_and_passes(N, M)
    # The fast path launches tpb threads in ONE block per batch element, so it
    # is only usable while tpb fits in a block (1024 is the CUDA per-block
    # thread limit on current GPUs).
    max_threads_per_block = 1024

    if tpb <= max_threads_per_block:
        # --- Fast path: one block per batch element ---
        softdtw_forward_kernel[B, tpb](
            D_ca,
            gamma_f,
            bw,
            N, M, n_passes,
            R_ca,
        )
    else:
        # --- Long sequence path: tiled anti-diagonal launches ---
        # Iterate anti-diagonals in unpadded (i,j) coords over D (shape N x M)
        for p in range(N + M - 1):
            i_min, i_max = _diag_bounds(p, N, M)
            if i_max < i_min:
                continue

            diag_len = i_max - i_min + 1
            grid_x = (diag_len + TPB_LONG - 1) // TPB_LONG  # ceil(diag_len / TPB_LONG)

            # grid=(grid_x, B) so batch index is blockIdx.y inside kernel
            softdtw_forward_diag_cuda[(grid_x, B), TPB_LONG](
                D_ca, R_ca,
                gamma_f, bw,
                N, M, p,
            )

    out = R[:, -2, -2].contiguous()
    return out, R
def softdtw_backward_cuda_log(D: torch.Tensor, R: torch.Tensor, gamma: float, bandwidth: float):
    """
    SoftDTW backward (log-space) over a precomputed distance matrix D.

    Args:
        D: (B,N,M) CUDA tensor of pairwise distances.
        R: (B,N+2,M+2) CUDA tensor (from forward). It is not modified; the
            boundary conditions are applied to a private copy.
        gamma: smoothing parameter, must be > 0.
        bandwidth: forwarded (as a float) to the kernels.

    Returns:
        E: (B,N,M) CUDA tensor, obtained as exp(logE) with the padding
        cropped.

    Raises:
        ValueError: if gamma <= 0, if D or R is not a CUDA tensor, if D is
            not 3-dimensional, or if R does not have shape (B,N+2,M+2).
    """
    if gamma <= 0:
        raise ValueError(f"gamma must be > 0, got {gamma}")
    if not D.is_cuda:
        raise ValueError("Expected CUDA tensor D")
    if not R.is_cuda:
        raise ValueError("Expected CUDA tensor R")
    if D.dim() != 3:
        raise ValueError(f"Expected D as (B,N,M). Got {tuple(D.shape)}")

    # Detach before passing to numba
    D_ = D.detach().contiguous()
    B, N, M = D_.shape

    # Same shape contract as the fused backward: R must be the padded table.
    if R.shape != (B, N + 2, M + 2):
        raise ValueError(f"Expected R shape {(B, N+2, M+2)}, got {tuple(R.shape)}")

    # ---------- pad D ----------
    D_pad = torch.zeros((B, N + 2, M + 2), device=D_.device, dtype=D_.dtype)
    D_pad[:, 1:N + 1, 1:M + 1] = D_

    # ---------- boundary conditions for R ----------
    # Work on a detached copy so the caller's R is left untouched.
    R_work = R.detach().contiguous().clone()
    R_work[:, :, -1] = -math.inf
    R_work[:, -1, :] = -math.inf
    R_work[:, -1, -1] = R_work[:, -2, -2]

    # ---------- init logE ----------
    logE = torch.full((B, N + 2, M + 2), -math.inf, device=D_.device, dtype=D_.dtype)
    logE[:, -1, -1] = 0.0  # log(1)

    Dp_ca = cuda.as_cuda_array(D_pad)
    Rw_ca = cuda.as_cuda_array(R_work)
    logE_ca = cuda.as_cuda_array(logE)
    inv_gamma = float(1.0 / gamma)

    # ---------- choose fast vs tiled ----------
    tpb, n_passes = _threads_and_passes(N, M)
    # The fast path launches tpb threads in ONE block per batch element, so it
    # is only usable while tpb fits in a block (1024 is the CUDA per-block
    # thread limit on current GPUs).
    max_threads_per_block = 1024

    if tpb <= max_threads_per_block:
        # fast path: diagonal backward kernel, single block per batch element
        softdtw_backward_log_cuda[B, tpb](
            Dp_ca,
            Rw_ca,
            inv_gamma,
            float(bandwidth),
            N, M, n_passes,
            logE_ca,
        )
    else:
        # tiled path: launch one kernel per anti-diagonal in reverse order
        bw = float(bandwidth)
        # NOTE(review): only this path normalizes non-positive bandwidths to
        # -1.0; the fast path passes the raw value. Confirm both kernels
        # treat non-positive bandwidths the same way.
        if bw <= 0:
            bw = -1.0

        # unpadded indices (i,j) are 0..N-1, 0..M-1, diagonals p = i+j
        for p in range(N + M - 2, -1, -1):
            i_min, i_max = _diag_bounds(p, N, M)
            if i_max < i_min:
                continue

            diag_len = i_max - i_min + 1
            grid_x = (diag_len + TPB_LONG - 1) // TPB_LONG  # ceil(diag_len / TPB_LONG)

            softdtw_backward_log_diag_cuda[(grid_x, B), TPB_LONG](
                Dp_ca, Rw_ca, logE_ca,
                inv_gamma, bw,
                N, M, p,
            )

    # crop the padding + leave log-space
    E = torch.exp(logE[:, 1:N + 1, 1:M + 1]).contiguous()
    return E
# ---- CPU reference (optional but useful for tests) ----
@jit(nopython=True, parallel=True)
def _softdtw_forward_cpu_np(D: np.ndarray, gamma: float, bandwidth: float):
    """Reference SoftDTW forward pass on a (B,N,M) NumPy distance array.

    Returns the padded (B,N+2,M+2) DP table R. Cells outside the band
    (``0 < bandwidth < |i - j|``) are skipped and therefore stay at +inf;
    a bandwidth <= 0 disables the band.
    """
    B, N, M = D.shape
    R = np.ones((B, N + 2, M + 2), dtype=D.dtype) * np.inf
    R[:, 0, 0] = 0.0
    for b in prange(B):
        for j in range(1, M + 1):
            for i in range(1, N + 1):
                if 0 < bandwidth < abs(i - j):
                    continue
                r0 = -R[b, i - 1, j - 1] / gamma
                r1 = -R[b, i - 1, j] / gamma
                r2 = -R[b, i, j - 1] / gamma
                # Soft-min via log-sum-exp, shifted by the max for stability.
                rmax = max(max(r0, r1), r2)
                rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
                softmin = -gamma * (np.log(rsum) + rmax)
                R[b, i, j] = D[b, i - 1, j - 1] + softmin
    return R


@jit(nopython=True, parallel=True)
def _softdtw_backward_cpu_np(D_: np.ndarray, R: np.ndarray, gamma: float, bandwidth: float):
    """Reference SoftDTW backward pass on NumPy arrays.

    ``D_`` is the (B,N,M) distance array and ``R`` the padded (B,N+2,M+2)
    table from the forward pass. Returns the (B,N,M) view of E with the
    padding cropped.

    NOTE: ``R`` is modified in place (boundary rows/columns and infinite
    entries are overwritten); pass a copy if the original must be kept.
    """
    B, N, M = D_.shape
    D = np.zeros((B, N + 2, M + 2), dtype=D_.dtype)
    D[:, 1:N + 1, 1:M + 1] = D_
    E = np.zeros((B, N + 2, M + 2), dtype=D_.dtype)
    E[:, -1, -1] = 1.0
    # Boundary conditions on the padded table.
    R[:, :, -1] = -np.inf
    R[:, -1, :] = -np.inf
    R[:, -1, -1] = R[:, -2, -2]
    for b in prange(B):
        for j in range(M, 0, -1):
            for i in range(N, 0, -1):
                if np.isinf(R[b, i, j]):
                    R[b, i, j] = -np.inf
                if 0 < bandwidth < abs(i - j):
                    continue
                a0 = (R[b, i + 1, j] - R[b, i, j] - D[b, i + 1, j]) / gamma
                b0 = (R[b, i, j + 1] - R[b, i, j] - D[b, i, j + 1]) / gamma
                c0 = (R[b, i + 1, j + 1] - R[b, i, j] - D[b, i + 1, j + 1]) / gamma
                a = np.exp(a0)
                bb = np.exp(b0)
                c = np.exp(c0)
                E[b, i, j] = E[b, i + 1, j] * a + E[b, i, j + 1] * bb + E[b, i + 1, j + 1] * c
    return E[:, 1:N + 1, 1:M + 1]
def softdtw_forward_cpu(D: torch.Tensor, gamma: float, bandwidth: float):
    """
    CPU reference SoftDTW forward.

    Copies D to host memory, runs the NumPy/numba reference implementation
    and moves the result back to D's device and dtype.

    Args:
        D: (B,N,M) tensor of pairwise distances.
        gamma: smoothing parameter, must be > 0.
        bandwidth: band width; values <= 0 disable the band in the
            reference implementation.

    Returns:
        (out, R) where out is (B,) and equals R[:, -2, -2], and R is the
        padded (B,N+2,M+2) DP table.

    Raises:
        ValueError: if gamma <= 0.
    """
    # Same gamma contract as the CUDA entry points.
    if gamma <= 0:
        raise ValueError(f"gamma must be > 0, got {gamma}")

    D_np = D.detach().cpu().numpy()
    R_np = _softdtw_forward_cpu_np(D_np, float(gamma), float(bandwidth))
    R = torch.from_numpy(R_np).to(D.device).type_as(D)
    out = R[:, -2, -2].contiguous()
    return out, R
def softdtw_backward_cpu(D: torch.Tensor, R: torch.Tensor, gamma: float, bandwidth: float):
    """
    CPU reference SoftDTW backward.

    Args:
        D: (B,N,M) tensor of pairwise distances.
        R: padded (B,N+2,M+2) DP table from the forward pass. It is not
            modified; the reference implementation works on a copy.
        gamma: smoothing parameter, must be > 0.
        bandwidth: band width; values <= 0 disable the band in the
            reference implementation.

    Returns:
        E: (B,N,M) contiguous tensor on D's device with D's dtype.

    Raises:
        ValueError: if gamma <= 0.
    """
    # Same gamma contract as the CUDA entry points.
    if gamma <= 0:
        raise ValueError(f"gamma must be > 0, got {gamma}")

    D_np = D.detach().cpu().numpy()
    # .copy() prevents in-place mutation of saved autograd tensor: for a CPU
    # tensor .numpy() shares memory, and the reference backward writes to R.
    R_np = R.detach().cpu().numpy().copy()
    E_np = _softdtw_backward_cpu_np(D_np, R_np, float(gamma), float(bandwidth))
    return torch.from_numpy(E_np).to(D.device).type_as(D).contiguous()