Source code for dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_ron.distances

from __future__ import annotations

import torch


def sqeuclidean(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Batched pairwise squared Euclidean distance.

    For every batch index ``b`` this returns the matrix
    ``out[b, i, j] = ||x[b, i] - y[b, j]||^2``. The square is expanded as
    ``||a||^2 + ||b||^2 - 2 <a, b>``, so the cross term comes from a single
    batched matrix multiply instead of an explicit difference tensor.

    Args:
        x: Tensor of shape (B, N, D).
        y: Tensor of shape (B, M, D).

    Returns:
        Tensor of shape (B, N, M), clamped to be >= 0.

    Raises:
        ValueError: If either input is not 3-D, or if the batch (B) or
            feature (D) sizes of ``x`` and ``y`` differ.
    """
    if not (x.dim() == 3 and y.dim() == 3):
        raise ValueError(f"Expected x,y as (B,N,D)/(B,M,D). Got {tuple(x.shape)} and {tuple(y.shape)}")

    batch_x, _, feat_x = x.shape
    batch_y, _, feat_y = y.shape
    if batch_x != batch_y or feat_x != feat_y:
        raise ValueError(f"Batch/features mismatch: {tuple(x.shape)} vs {tuple(y.shape)}")

    sq_norm_rows = (x * x).sum(dim=-1)[:, :, None]  # (B, N, 1)
    sq_norm_cols = (y * y).sum(dim=-1)[:, None, :]  # (B, 1, M)
    cross = torch.bmm(x, y.transpose(-1, -2))  # (B, N, M)

    dist_sq = sq_norm_rows + sq_norm_cols - 2.0 * cross
    # Floating-point cancellation in the expansion can leave tiny negative
    # values where points (nearly) coincide; clip them back to zero.
    return dist_sq.clamp_min(0.0)


def pairwise_distance(x: torch.Tensor, y: torch.Tensor, *, dist: str) -> torch.Tensor:
    """Compute a batched pairwise distance matrix for the metric named by ``dist``.

    Args:
        x: First input, passed through to the selected metric. For squared
            Euclidean this is a tensor of shape (B, N, D).
        y: Second input, passed through to the selected metric. For squared
            Euclidean this is a tensor of shape (B, M, D).
        dist: Metric name. It is matched case-insensitively, surrounding
            whitespace is ignored, and ``-`` is treated like ``_``. Currently
            only squared Euclidean is supported, under the aliases
            ``"sqeuclidean"``, ``"sq_euclidean"`` and ``"squared_euclidean"``.

    Returns:
        The distance tensor produced by the selected metric; for squared
        Euclidean, shape (B, N, M).

    Raises:
        ValueError: If ``dist`` does not name a supported metric, or if the
            selected metric rejects the input shapes.
    """
    key = dist.strip().lower().replace("-", "_")
    if key in ("sqeuclidean", "sq_euclidean", "squared_euclidean"):
        return sqeuclidean(x, y)
    # Echo the caller's original spelling (not the normalized key) so the
    # message shows exactly what was passed in.
    raise ValueError(f"Unknown dist='{dist}'. Supported: 'sqeuclidean'")