SoftDTW (Mehran Maghoumi)
CUDA implementation of SoftDTW by Mehran Maghoumi, with a CPU fallback. This is one of the most popular and widely used SoftDTW implementations available online.
The code in this file is identical to that of the original repository, except that the docstrings have been converted to numpydoc style for consistency with the rest of the codebase.
If you use this implementation, please cite [Mag20] and [MTL21].
GitHub repository: https://github.com/Maghoumi/pytorch-softdtw-cuda
Example
>>> import torch
>>> from dtw_loss_functions import soft_dtw
>>> use_cuda = torch.cuda.is_available()
>>> device = 'cuda' if use_cuda else 'cpu'
>>> batch_size = 5
>>> time_samples = 300
>>> channels = 1
>>> x = torch.randn(batch_size, time_samples, channels).to(device)
>>> x_r = torch.randn(batch_size, time_samples, channels).to(device)
>>> sdtw_config = {'use_cuda': use_cuda}
>>> sdtw_loss = soft_dtw.soft_dtw(implementation='mag', sdtw_config=sdtw_config)
>>> output_sdtw = sdtw_loss(x, x_r)
- class dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_mag.SoftDTW(use_cuda, gamma=1.0, normalize=False, bandwidth=None, dist_func=None)
Bases: torch.nn.Module
The soft DTW implementation that optionally supports CUDA.
Attributes
- normalize
Flag indicating whether to perform normalization (as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790). Note that if normalize is set to True, the SoftDTW divergence is computed instead, which is defined as
SoftDTW(X, Y) - 1/2 * (SoftDTW(X, X) + SoftDTW(Y, Y)).
- Type:
bool
- gamma
SoftDTW’s smoothing parameter gamma (must be positive); as gamma approaches 0, soft-DTW approaches classical DTW.
- Type:
float
- bandwidth
Sakoe-Chiba bandwidth for pruning: cells with |i - j| > bandwidth are skipped. Passing None disables pruning.
- Type:
float or None
- dist_func
Optional point-wise distance function to use. If None, a default (squared) Euclidean distance function is used.
- Type:
function or None
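The gamma attribute controls the soft minimum that replaces the hard minimum of classical DTW. The sketch below is a plain-Python illustration (the helper name `softmin` is hypothetical, not part of this module) of softmin_gamma(a_1, ..., a_n) = -gamma * log(sum_i exp(-a_i / gamma)), computed with the usual max-shift for numerical stability. The result never exceeds the hard minimum, is at least min - gamma * log(n), and approaches the hard minimum as gamma shrinks.

```python
import math


def softmin(values, gamma):
    """Soft minimum: -gamma * log(sum(exp(-v / gamma) for v in values)).

    Hypothetical helper for illustration. A +inf input (an unreachable
    cell in the DTW table) contributes exp(-inf) = 0 to the sum.
    """
    scaled = [-v / gamma for v in values]
    shift = max(scaled)  # subtract the largest exponent for numerical stability
    if shift == -math.inf:  # every input is +inf
        return math.inf
    total = sum(math.exp(s - shift) for s in scaled)
    return -gamma * (math.log(total) + shift)


costs = [1.0, 2.0, 3.0]
for gamma in (1.0, 0.1, 0.01):
    # Never above min(costs); approaches min(costs) as gamma -> 0
    print(gamma, softmin(costs, gamma))
```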
- Parameters:
use_cuda (bool) – Flag indicating whether the CUDA implementation should be used
gamma (float) – SoftDTW’s smoothing parameter gamma (must be positive); as gamma approaches 0, soft-DTW approaches classical DTW.
normalize (bool) – Flag indicating whether to perform normalization (as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790). Note that if normalize is set to True, the SoftDTW divergence is computed instead, which is defined as
SoftDTW(X, Y) - 1/2 * (SoftDTW(X, X) + SoftDTW(Y, Y)).
bandwidth (float or None) – Sakoe-Chiba bandwidth for pruning: cells with |i - j| > bandwidth are skipped. Passing None disables pruning.
dist_func (function or None) – Optional point-wise distance function to use. If None, a default (squared) Euclidean distance function is used.
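With normalize=True, the original code computes all three terms in a single batched call: it concatenates [X, X, Y] and [Y, X, Y] along the batch dimension, runs soft-DTW once, and splits the result into three equal parts. The plain-Python sketch below mirrors that idea for batches of 1-D sequences. The names `soft_dtw` and `soft_dtw_divergence`, the squared-difference cost, and bandwidth=0 meaning "no pruning" are assumptions of this sketch, not the module's API. It also shows why the normalization matters: soft-DTW of a sequence with itself is generally not zero, while the divergence of a sequence with itself is exactly zero.

```python
import math


def _softmin3(a, b, c, gamma):
    # Numerically stable soft minimum of three costs (+inf = unreachable cell)
    scaled = (-a / gamma, -b / gamma, -c / gamma)
    shift = max(scaled)
    if shift == -math.inf:
        return math.inf
    return -gamma * (math.log(sum(math.exp(s - shift) for s in scaled)) + shift)


def soft_dtw(x, y, gamma=1.0, bandwidth=0.0):
    """Soft-DTW between two 1-D sequences, squared-difference cost (assumption)."""
    n, m = len(x), len(y)
    # Table padded with +inf and R[0][0] = 0; bandwidth 0 means "no pruning"
    R = [[math.inf] * (m + 1) for _ in range(n + 1)]
    R[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if 0 < bandwidth < abs(i - j):  # outside the Sakoe-Chiba band
                continue
            cost = (x[i - 1] - y[j - 1]) ** 2
            R[i][j] = cost + _softmin3(R[i - 1][j - 1], R[i - 1][j], R[i][j - 1], gamma)
    return R[n][m]


def soft_dtw_divergence(X, Y, gamma=1.0, bandwidth=0.0):
    """SDTW(X, Y) - 1/2 * (SDTW(X, X) + SDTW(Y, Y)) for each pair in the batch."""
    # Stack the three comparisons into one batch: [X, X, Y] against [Y, X, Y]
    lhs, rhs = X + X + Y, Y + X + Y
    out = [soft_dtw(a, b, gamma, bandwidth) for a, b in zip(lhs, rhs)]
    k = len(X)
    xy, xx, yy = out[:k], out[k:2 * k], out[2 * k:]
    return [v_xy - 0.5 * (v_xx + v_yy) for v_xy, v_xx, v_yy in zip(xy, xx, yy)]


X = [[0.0, 1.0, 2.0], [1.0, 1.0, 0.0]]  # batch of two 1-D sequences
Y = [[0.0, 2.0, 2.0], [1.0, 0.0, 0.0]]
print(soft_dtw(X[0], X[0]))       # negative, not zero: soft-min lies below the hard min
print(soft_dtw_divergence(X, X))  # [0.0, 0.0]
print(soft_dtw_divergence(X, Y))
```

In the tensor version, concatenating along the batch dimension requires X and Y to share the same seq_len, so this batched trick only applies to equal-length inputs.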
Methods
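forward(X, Y)
Compute the soft-DTW value between X and Y.
- Parameters:
X – One batch of examples, of shape batch_size x seq_len x dims.
Y – The other batch of examples, of shape batch_size x seq_len x dims.
- Returns:
The computed soft-DTW values (or divergences, if normalize is True), one per pair of examples in the batch.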
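forward builds a pairwise cost tensor with dist_func before running the soft-DTW recursion. As far as we can tell from the original repository, the default used when dist_func is None sums squared differences over the dims axis, i.e. it is the squared Euclidean distance. The sketch below reproduces that contract on nested lists so it runs without PyTorch; in the real module dist_func receives torch tensors. The names `pairwise_sq_euclidean` and `pairwise_l1` are hypothetical, and the L1 variant only illustrates what a custom dist_func could look like: inputs of shape batch_size x seq_len_x x dims and batch_size x seq_len_y x dims, output of shape batch_size x seq_len_x x seq_len_y.

```python
def pairwise_sq_euclidean(X, Y):
    """Squared Euclidean cost between every pair of time steps.

    X: batch_size x n x dims, Y: batch_size x m x dims (nested lists).
    Returns D with D[b][i][j] = sum_k (X[b][i][k] - Y[b][j][k]) ** 2,
    i.e. a batch_size x n x m cost tensor.
    """
    assert len(X) == len(Y), "both batches must have the same batch_size"
    D = []
    for xs, ys in zip(X, Y):
        D.append([[sum((a - b) ** 2 for a, b in zip(xi, yj)) for yj in ys] for xi in xs])
    return D


def pairwise_l1(X, Y):
    """Hypothetical alternative dist_func: summed absolute (L1) differences."""
    return [[[sum(abs(a - b) for a, b in zip(xi, yj)) for yj in ys] for xi in xs]
            for xs, ys in zip(X, Y)]


X = [[[0.0, 0.0], [1.0, 1.0]]]              # batch_size=1, n=2, dims=2
Y = [[[0.0, 1.0], [2.0, 2.0], [1.0, 1.0]]]  # batch_size=1, m=3, dims=2
print(pairwise_sq_euclidean(X, Y))          # [[[1.0, 8.0, 2.0], [1.0, 2.0, 0.0]]]
print(pairwise_l1(X, Y))                    # [[[1.0, 4.0, 2.0], [1.0, 2.0, 0.0]]]
```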
- dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_mag.compute_softdtw(D, gamma, bandwidth)
- dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_mag.compute_softdtw_backward(D_, R, gamma, bandwidth)
- dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_mag.compute_softdtw_backward_cuda(D, R, inv_gamma, bandwidth, max_i, max_j, n_passes, E)
- dtw_loss_functions.soft_dtw_implementations.soft_dtw_cuda_mag.compute_softdtw_cuda(D, gamma, bandwidth, max_i, max_j, n_passes, R)
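CUDA kernel for the forward soft-DTW recursion. Each CUDA block processes one pair of sequences from the batch, with one thread per row, and the threads sweep the cost matrix one anti-diagonal at a time over n_passes passes, writing the accumulated costs into R.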
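The CUDA kernel can run in parallel because, in the soft-DTW recursion, R[i][j] depends only on its upper-left, upper and left neighbours, so every cell on one anti-diagonal can be computed independently once the previous anti-diagonals are done. The plain-Python sketch below is an illustration only, not a port of the Numba or CUDA code, and all of its names (`softmin3`, `update_cell`, `forward`, `backward`) are hypothetical. It handles a single pair and fills the padded table R either column by column, similar to the CPU loop in compute_softdtw, or one anti-diagonal per pass (n + m - 1 passes here), which is the kind of sweep compute_softdtw_cuda performs. It also sketches the reverse recursion that produces E, the gradient of the final soft-DTW value with respect to each entry of D, modelled on our reading of compute_softdtw_backward.

```python
import math

INF = math.inf


def softmin3(a, b, c, gamma):
    # Stable soft minimum of three costs; +inf marks an unreachable cell
    scaled = (-a / gamma, -b / gamma, -c / gamma)
    shift = max(scaled)
    if shift == -INF:
        return INF
    return -gamma * (math.log(sum(math.exp(s - shift) for s in scaled)) + shift)


def update_cell(R, D, i, j, gamma, bandwidth):
    # R[i][j] depends only on its upper-left, upper and left neighbours
    if 0 < bandwidth < abs(i - j):  # outside the Sakoe-Chiba band: stays +inf
        return
    R[i][j] = D[i - 1][j - 1] + softmin3(R[i - 1][j - 1], R[i - 1][j], R[i][j - 1], gamma)


def forward(D, gamma, bandwidth=0.0, wavefront=False):
    """Return the padded (n + 2) x (m + 2) table R; the soft-DTW value is R[n][m]."""
    n, m = len(D), len(D[0])
    R = [[INF] * (m + 2) for _ in range(n + 2)]
    R[0][0] = 0.0
    if wavefront:
        # Pass p covers the anti-diagonal i + j = p + 2; its cells are independent,
        # which is what lets a GPU assign them to parallel threads
        for p in range(n + m - 1):
            for i in range(max(1, p + 2 - m), min(n, p + 1) + 1):
                update_cell(R, D, i, p + 2 - i, gamma, bandwidth)
    else:
        # Plain column-by-column sweep
        for j in range(1, m + 1):
            for i in range(1, n + 1):
                update_cell(R, D, i, j, gamma, bandwidth)
    return R


def backward(D, R, gamma, bandwidth=0.0):
    """Return E (n x m, 0-based) with E[i][j] = dR[n][m] / dD[i][j]."""
    n, m = len(D), len(D[0])
    Dp = [[0.0] * (m + 2) for _ in range(n + 2)]  # D padded to R's shape
    for i in range(n):
        Dp[i + 1][1:m + 1] = D[i]
    R = [row[:] for row in R]  # work on a copy of the forward table
    for row in R:
        row[m + 1] = -INF
    R[n + 1] = [-INF] * (m + 2)
    R[n + 1][m + 1] = R[n][m]
    E = [[0.0] * (m + 2) for _ in range(n + 2)]
    E[n + 1][m + 1] = 1.0
    for j in range(m, 0, -1):
        for i in range(n, 0, -1):
            if R[i][j] == INF:  # unreachable cells must not feed gradient back
                R[i][j] = -INF
            if 0 < bandwidth < abs(i - j):
                continue
            a = math.exp((R[i + 1][j] - R[i][j] - Dp[i + 1][j]) / gamma)
            b = math.exp((R[i][j + 1] - R[i][j] - Dp[i][j + 1]) / gamma)
            c = math.exp((R[i + 1][j + 1] - R[i][j] - Dp[i + 1][j + 1]) / gamma)
            E[i][j] = E[i + 1][j] * a + E[i][j + 1] * b + E[i + 1][j + 1] * c
    return [row[1:m + 1] for row in E[1:n + 1]]


D = [[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0], [9.0, 4.0, 1.0]]  # n=4, m=3
gamma = 0.5
R = forward(D, gamma)
print(R == forward(D, gamma, wavefront=True))  # True: both sweep orders agree
E = backward(D, R, gamma)

# One-sided finite-difference check of a single gradient entry
eps = 1e-6
D_eps = [row[:] for row in D]
D_eps[1][1] += eps
print(E[1][1], (forward(D_eps, gamma)[4][3] - R[4][3]) / eps)
```

Comparing E with a finite difference, as at the end of the sketch, is a cheap way to sanity-check any reimplementation of the backward pass.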