Source code for dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_ron.barycenters

"""
SoftDTW Barycenter Averaging

Implements time series averaging using soft Dynamic Time Warping geometry.
Based on the method from Cuturi & Blondel (ICML 2017).

Reference: https://github.com/tslearn-team/tslearn/blob/main/tslearn/barycenters/softdtw.py
"""

from __future__ import annotations

import time
import torch
from .module import SoftDTW
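

# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): the barycenter below
# minimizes  F(z) = sum_b w_b * softDTW_gamma(z, X_b)  over z, where soft-DTW
# replaces DTW's hard minimum with
#     softmin_gamma(a) = -gamma * log(sum_i exp(-a_i / gamma)).
# The hypothetical helper `_softdtw_reference` spells out that recursion in
# naive O(N * M) form, assuming a squared-Euclidean alignment cost; the
# imported SoftDTW module computes the same quantity (and its gradient) with
# batched/fused kernels.
# ---------------------------------------------------------------------------
def _softdtw_reference(x: torch.Tensor, y: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    """Naive soft-DTW value between x of shape (N, D) and y of shape (M, D)."""
    dist = torch.cdist(x, y) ** 2  # pairwise squared Euclidean alignment costs
    n, m = dist.shape
    r = torch.full((n + 1, m + 1), float('inf'), dtype=dist.dtype, device=dist.device)
    r[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # Soft minimum over the three alignment predecessors
            prev = torch.stack([r[i - 1, j - 1], r[i - 1, j], r[i, j - 1]])
            r[i, j] = dist[i - 1, j - 1] - gamma * torch.logsumexp(-prev / gamma, dim=0)
    return r[n, m]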


def softdtw_barycenter(
    X: torch.Tensor,
    *,
    gamma: float = 1.0,
    weights: torch.Tensor | None = None,
    max_iter: int = 100,
    lr: float = 0.1,
    init: torch.Tensor | None = None,
    device: str | torch.device | None = None,
    verbose: bool = False,
    fused: bool | None = None,
    early_stopping: bool = True,
    patience: int = 10,
    tol: float = 1e-5,
) -> torch.Tensor:
    """
    Compute a SoftDTW barycenter (time series average) through optimization.

    This function finds the barycenter that minimizes the weighted sum of
    SoftDTW distances to all input time series using gradient-based
    optimization.

    Args:
        X: Input time series of shape (B, N, D) where:
            - B: batch size (number of sequences)
            - N: sequence length
            - D: feature dimension
        gamma: SoftDTW regularization parameter. Default: 1.0
        weights: Optional weights for each sequence, shape (B,). Default: uniform
        max_iter: Maximum optimization iterations. Default: 100
        lr: Learning rate for optimization. Default: 0.1
        init: Initial barycenter, shape (N, D). If None, uses weighted mean.
            Default: None
        device: Device to compute on. If None, uses X's device. Default: None
        verbose: Print iteration progress and timing. Default: False
        fused: Fused mode selection. Default: None (auto-select)
            - None: Auto-select (use fused if CUDA available)
            - True: Require fused mode (error if not available)
            - False: Never use fused mode (always use standard distance matrix)
        early_stopping: Stop early if loss plateaus. Default: True
        patience: Iterations without improvement before stopping. Default: 10
        tol: Absolute improvement threshold for early stopping. Default: 1e-5
            Note: Uses absolute improvement (best_loss - loss_val > tol),
            which handles negative SoftDTW values correctly.

    Returns:
        Barycenter of shape (N, D)
    """
    # Example:
    # >>> import torch
    # >>> from dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_ron.barycenters import softdtw_barycenter
    # >>> X = torch.randn(16, 100, 3, device="cuda")  # 16 sequences of length 100, dim 3
    # >>> barycenter = softdtw_barycenter(X, gamma=1.0, max_iter=50, verbose=True)
    # >>> barycenter.shape
    # torch.Size([100, 3])
    #
    # >>> # Force fused mode for memory efficiency
    # >>> barycenter_fused = softdtw_barycenter(X, fused=True)
    #
    # >>> # Force unfused mode for predictable performance
    # >>> barycenter_unfused = softdtw_barycenter(X, fused=False)
    device = device or X.device

    # Move X to target device first
    X = X.to(device)
    B, N, D = X.shape

    # Normalize weights
    if weights is None:
        weights = torch.ones(B, device=device) / B
    else:
        weights = weights.to(device)
        weights = weights / weights.sum()

    # Initialize barycenter with weighted mean (better than unweighted mean)
    if init is None:
        barycenter = (X * weights.view(B, 1, 1)).sum(dim=0).clone()
    else:
        barycenter = init.clone().to(device)

    # Ensure barycenter requires gradients
    barycenter = barycenter.requires_grad_(True)

    # Create SoftDTW loss function
    loss_fn = SoftDTW(gamma=gamma, normalize=False, fused=fused)

    # Optimizer
    optimizer = torch.optim.Adam([barycenter], lr=lr)

    # Learning rate scheduler: cosine annealing for better convergence
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max_iter, eta_min=lr * 0.1
    )

    # Start timing
    opt_start_time = time.time()

    best_loss = float('inf')
    patience_counter = 0

    # Optimization loop
    for iteration in range(max_iter):
        # Synchronize before timing for accurate CUDA measurements
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        optimizer.zero_grad()

        # Expand barycenter to batch size for comparison
        barycenter_batch = barycenter.unsqueeze(0).expand(B, -1, -1)

        # Compute SoftDTW loss to all sequences
        distances = loss_fn(barycenter_batch, X)  # shape: (B,)

        # Weighted loss (can be negative due to SoftDTW soft-min aggregation)
        loss = (weights * distances).sum()

        # Backprop and optimization step
        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_([barycenter], max_norm=1.0)

        optimizer.step()
        scheduler.step()

        loss_val = loss.item()
        improvement = float('nan')  # Track improvement for logging

        # Early stopping: track absolute improvement (works for negative losses)
        if early_stopping:
            improvement = best_loss - loss_val
            if improvement > tol:
                # Absolute improvement (works for negative losses)
                best_loss = loss_val
                patience_counter = 0
            else:
                patience_counter += 1

            # Stop if no improvement for 'patience' iterations (after warmup)
            if patience_counter >= patience and iteration > max_iter // 2:
                if verbose:
                    print(
                        f"Early stopping at iteration {iteration + 1} "
                        f"(no improvement for {patience} iterations)"
                    )
                break

        # Optional: Print progress with timing
        if verbose and (iteration + 1) % 20 == 0:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed = time.time() - opt_start_time
            lr_val = optimizer.param_groups[0]['lr']

            # Format improvement string (may be NaN if early_stopping disabled)
            if early_stopping:
                improvement_str = f"{improvement:9.6f}"
            else:
                improvement_str = "   N/A   "

            print(
                f"Iteration {iteration + 1:3d}/{max_iter} | "
                f"Loss: {loss_val:9.6f} | "
                f"Improvement: {improvement_str} | "
                f"LR: {lr_val:.2e} | "
                f"Elapsed: {elapsed:6.1f}s"
            )

    return barycenter.detach()
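

# Usage sketch (editorial addition, illustrative only): non-uniform weights
# pull the barycenter toward heavily weighted sequences. Weights are
# renormalized internally, so any positive scale works.
# >>> X = torch.randn(4, 50, 2)
# >>> w = torch.tensor([4.0, 1.0, 1.0, 1.0])  # emphasize the first sequence
# >>> bc = softdtw_barycenter(X, weights=w, gamma=0.1, max_iter=200)
# >>> bc.shape
# torch.Size([50, 2])
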
def softdtw_barycenter_cpu(
    X: torch.Tensor,
    *,
    gamma: float = 1.0,
    weights: torch.Tensor | None = None,
    max_iter: int = 100,
    lr: float = 0.1,
    init: torch.Tensor | None = None,
    verbose: bool = False,
    fused: bool | None = None,
    early_stopping: bool = True,
    patience: int = 10,
    tol: float = 1e-5,
) -> torch.Tensor:
    """
    Compute a SoftDTW barycenter on CPU (convenience wrapper).

    Args:
        X: Input time series of shape (B, N, D)
        gamma: SoftDTW regularization parameter. Default: 1.0
        weights: Optional weights for each sequence. Default: uniform
        max_iter: Maximum optimization iterations. Default: 100
        lr: Learning rate for optimization. Default: 0.1
        init: Initial barycenter. If None, uses weighted mean. Default: None
        verbose: Print iteration progress and timing. Default: False
        fused: Fused mode selection. Default: None (auto-select)
        early_stopping: Stop early if loss plateaus. Default: True
        patience: Iterations without improvement before stopping. Default: 10
        tol: Improvement threshold for early stopping. Default: 1e-5

    Returns:
        Barycenter of shape (N, D)
    """
    return softdtw_barycenter(
        X,
        gamma=gamma,
        weights=weights,
        max_iter=max_iter,
        lr=lr,
        init=init,
        device="cpu",
        verbose=verbose,
        fused=fused,
        early_stopping=early_stopping,
        patience=patience,
        tol=tol,
    )
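

# Hypothetical smoke test (editorial addition, not in the original file):
# average three noisy sine waves on CPU; the result should be a smooth
# sine-like curve with the same (length, dim) shape as each input.
if __name__ == "__main__":
    t = torch.linspace(0.0, 6.283, 80).unsqueeze(-1)  # one period, shape (80, 1)
    X = torch.stack([torch.sin(t) + 0.1 * torch.randn_like(t) for _ in range(3)])
    bc = softdtw_barycenter_cpu(X, gamma=0.1, max_iter=50, verbose=True)
    print("barycenter shape:", tuple(bc.shape))  # -> (80, 1)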