Source code for dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_ron.autograd_xy

from __future__ import annotations

import torch
from torch.autograd import Function

from .cuda.launcher import (
    softdtw_forward_cuda_fused_sqeuclid,
    softdtw_backward_cuda_fused_sqeuclid,   # returns E (exp(logE))
)


class SoftDTWXYAutograd(Function):
    @staticmethod
    def forward(ctx, X: torch.Tensor, Y: torch.Tensor, gamma: float, bandwidth: float | None):
        # Forward CUDA fused: returns out (B,) and R (B,N+2,M+2)
        out, R = softdtw_forward_cuda_fused_sqeuclid(
            X, Y, float(gamma), -1.0 if bandwidth is None else float(bandwidth)
        )
        # Save X,Y for gradient math; save detached R (no graph needed)
        ctx.save_for_backward(X, Y, R.detach())
        ctx.gamma = float(gamma)
        # Normalize bandwidth semantics: <=0 means disabled
        if bandwidth is None:
            ctx.bandwidth = -1.0
        else:
            bw = float(bandwidth)
            ctx.bandwidth = -1.0 if bw <= 0 else bw
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        X, Y, R = ctx.saved_tensors
        gamma = ctx.gamma
        bw = ctx.bandwidth
        # Compute E via fused log-space backward (Numba). Pass detached X/Y to be safe.
        E = softdtw_backward_cuda_fused_sqeuclid(X.detach(), Y.detach(), R, gamma, bw)  # (B,N,M)
        # Scale by upstream grad (B,) -> (B,1,1)
        g = grad_output.reshape(-1).to(device=X.device, dtype=X.dtype).view(-1, 1, 1)
        E = E * g
        # Reductions for sqeuclidean chain rule
        EX = E.sum(dim=2)  # (B,N)
        EY = E.sum(dim=1)  # (B,M)
        grad_X = 2.0 * (X * EX.unsqueeze(2) - torch.bmm(E, Y))                  # (B,N,D)
        grad_Y = 2.0 * (Y * EY.unsqueeze(2) - torch.bmm(E.transpose(1, 2), X))  # (B,M,D)
        return grad_X, grad_Y, None, None
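A minimal usage sketch follows. It assumes a CUDA device and that the fused launchers imported above are available; the convenience wrapper name soft_dtw_xy is hypothetical and not part of this module. For the squared-Euclidean cost d(x_i, y_j) = ||x_i - y_j||^2, the chain rule gives dL/dx_i = 2 * sum_j E[i, j] * (x_i - y_j), which is exactly the EX / torch.bmm(E, Y) reduction in backward (and symmetrically for Y).

def soft_dtw_xy(X: torch.Tensor, Y: torch.Tensor, gamma: float = 1.0,
                bandwidth: float | None = None) -> torch.Tensor:
    # Hypothetical convenience wrapper: batched soft-DTW between X (B,N,D)
    # and Y (B,M,D) with squared-Euclidean cost; returns a (B,) loss tensor.
    return SoftDTWXYAutograd.apply(X, Y, gamma, bandwidth)


if __name__ == "__main__":
    # Assumes a CUDA-capable device; shapes are arbitrary, for illustration only.
    B, N, M, D = 4, 32, 40, 8
    X = torch.randn(B, N, D, device="cuda", requires_grad=True)
    Y = torch.randn(B, M, D, device="cuda", requires_grad=True)
    loss = soft_dtw_xy(X, Y, gamma=0.1).mean()
    loss.backward()
    print(X.grad.shape, Y.grad.shape)  # torch.Size([4, 32, 8]) torch.Size([4, 40, 8])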