Soft DTW
Since several implementations of SoftDTW are available online, this module provides a wrapper that makes it easy to switch between them.
Currently, the following implementations are available:
- pytorch-softdtw-cuda by Mehran Maghoumi [Mag20] [MTL21]
GitHub repository: https://github.com/Maghoumi/pytorch-softdtw-cuda
- pysdtw by Antoine Loriette
GitHub repository: https://github.com/toinsson/pysdtw
PyPI page: https://pypi.org/project/pysdtw/
- sdtw-cuda-torch by BGU-CS-VIL (implemented by Ron Shapira Weber) [WF26] [SWBL+25]
GitHub repository: https://github.com/BGU-CS-VIL/sdtw-cuda-torch
If you use this module, please cite the original paper of the implementation you are using together with this package.
Example
Mehran Maghoumi’s (mag) implementation
>>> import torch
>>> from dtw_loss_functions import soft_dtw
>>> use_cuda = torch.cuda.is_available()
>>> sdtw_loss = soft_dtw.soft_dtw(implementation = 'mag', sdtw_config = {'use_cuda' : use_cuda, 'gamma' : 0.1})
>>> device = 'cuda' if use_cuda else 'cpu'
>>> batch_size = 5
>>> time_samples = 300
>>> channels = 1
>>> x = torch.randn(batch_size, time_samples, channels).to(device)
>>> x_r = torch.randn(batch_size, time_samples, channels).to(device)
>>> output_sdtw = sdtw_loss(x, x_r)
Ron Shapira Weber’s (ron) implementation
>>> import torch
>>> from dtw_loss_functions import soft_dtw
>>> sdtw_loss = soft_dtw.soft_dtw(implementation = 'ron', sdtw_config = {'gamma' : 0.1, 'dist' : 'sqeuclidean'})
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> batch_size = 5
>>> time_samples = 300
>>> channels = 1
>>> x = torch.randn(batch_size, time_samples, channels).to(device)
>>> x_r = torch.randn(batch_size, time_samples, channels).to(device)
>>> output_sdtw = sdtw_loss(x, x_r)
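Antoine Loriette's (pysdtw) implementation
The pysdtw backend can be selected in the same way. The configuration keys below (use_cuda, gamma) simply mirror the mag example above and are an assumption rather than an exhaustive list of what pysdtw accepts.
>>> import torch
>>> from dtw_loss_functions import soft_dtw
>>> # Sketch: assumes the pysdtw backend accepts the same common keys as the mag example
>>> use_cuda = torch.cuda.is_available()
>>> sdtw_loss = soft_dtw.soft_dtw(implementation = 'pysdtw', sdtw_config = {'use_cuda' : use_cuda, 'gamma' : 0.1})
>>> device = 'cuda' if use_cuda else 'cpu'
>>> x = torch.randn(5, 300, 1).to(device)
>>> x_r = torch.randn(5, 300, 1).to(device)
>>> output_sdtw = sdtw_loss(x, x_r)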
- class dtw_loss_functions.soft_dtw.soft_dtw(implementation: str = 'mag', sdtw_config: dict = {})[source]
Bases: Module

SoftDTW class. This class is a wrapper for the different implementations of the SoftDTW. The implementation can be selected by passing the implementation argument to the constructor. The available implementations are:
- mag: pytorch-softdtw-cuda by Mehran Maghoumi
- pysdtw: pysdtw by Antoine Loriette
- ron: sdtw-cuda-torch by BGU-CS-VIL (implemented by Ron Shapira Weber)

Together with the implementation, a configuration dictionary can be passed to the constructor to set the parameters of the SDTW function. The available parameters depend on the selected implementation, but some parameters are common to all implementations (e.g. gamma, normalize, bandwidth). The available parameters are listed below.
- Parameters:
implementation (str) – Implementation to use for the SDTW.
sdtw_config (dict, optional) – Configuration dictionary for the SDTW function. Note that if a parameter is not specified in the configuration dictionary, the default value will be used (the default values are specified in the description of each parameter). The dictionary can contain the following keys:
- use_cuda : bool
  If True, this class will use the CUDA implementation of the SDTW. Only for the mag and pysdtw implementations. Default is False.
- gamma : float, optional
  Value of the gamma hyperparameter for the SDTW. Default is 1.
- normalize : bool, optional
  If True, the SDTW divergence will be computed instead of the SDTW. Default is False.
- bandwidth : float, optional
  Sakoe-Chiba bandwidth for pruning. If None is given, no pruning is applied. Default is None.
- dist_func : function, optional
  Only for the mag and pysdtw implementations. If passed, this function will be used as the distance function for the SDTW. If None, the default distance function of the implementation will be used (squared Euclidean distance for both implementations). Default is None.
- dist : str, optional
  Only for the ron implementation. It has the same purpose as dist_func for the other implementations, but in this case it must be a string.
- fused : bool, optional
  Only for the ron implementation.
  None -> auto (use fused only when possible)
  True -> require fused (error if not possible)
  False -> never fused (always materialize D and use D-based autograd)
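As an illustration, a configuration enabling the divergence variant and Sakoe-Chiba pruning could be passed as follows (a minimal sketch using the common keys listed above; the gamma and bandwidth values are arbitrary):
>>> from dtw_loss_functions import soft_dtw
>>> # Illustrative values only; see the parameter descriptions above for defaults
>>> config = {'gamma' : 0.1, 'normalize' : True, 'bandwidth' : 30}
>>> sdtw_div = soft_dtw.soft_dtw(implementation = 'mag', sdtw_config = config)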
Methods
check_implementation(implementation)
  Check if the selected implementation is valid.
create_sdtw_function(sdtw_config)
  Create and return an instance of the SDTW function based on the current implementation and parameters.
forward(x, y)
  Compute the SoftDTW distance between two time series.
set_implementation(implementation[, ...])
  Set the implementation to use for the SDTW.
set_sdtw_config([sdtw_config, ...])
  Set the configuration for the SDTW function.
- check_implementation(implementation: str)[source]
Check if the selected implementation is valid. If not, raise an error.
- create_sdtw_function(sdtw_config: dict)[source]
Create and return an instance of the SDTW function based on the current implementation and parameters.
- Parameters:
sdtw_config (dict) – Configuration dictionary for the SDTW function. The keys of the dictionary are the same as the parameters of the constructor.
- forward(x: Tensor, y: Tensor) → Tensor[source]
Compute the SoftDTW distance between two time series.
- Parameters:
x (torch.Tensor) – First input tensor of shape B x T x C
y (torch.Tensor) – Second input tensor of shape B x T x C
- Returns:
SoftDTW distance between the two time series
- Return type:
torch.Tensor
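Since forward returns a tensor, the output can be reduced and backpropagated like any other PyTorch loss. A minimal sketch, assuming the selected implementation is differentiable with respect to its inputs:
>>> import torch
>>> from dtw_loss_functions import soft_dtw
>>> sdtw_loss = soft_dtw.soft_dtw(implementation = 'mag', sdtw_config = {'gamma' : 0.1})
>>> x = torch.randn(5, 300, 1, requires_grad = True)
>>> x_r = torch.randn(5, 300, 1)
>>> loss = sdtw_loss(x, x_r).mean()  # reduce to a scalar before calling backward
>>> loss.backward()
>>> x.grad.shape
torch.Size([5, 300, 1])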
- set_implementation(implementation: str, reset_sdtw_function: bool = True)[source]
Set the implementation to use for the SDTW. This function can be used to change the implementation after the class has been initialized.
- set_sdtw_config(sdtw_config: dict = {}, reset_sdtw_function: bool = True)[source]
Set the configuration for the SDTW function. This function can be used to change the configuration after the class has been initialized.
- Parameters:
sdtw_config (dict) – Configuration dictionary for the SDTW function. The keys of the dictionary are the same as the parameters of the constructor. Note that if a key is absent from the dictionary, the default value for that parameter will be used (the default values are specified in the description of each parameter).
reset_sdtw_function (bool, optional) – If True, the SDTW function will be reset with the new configuration. If False, the SDTW function will not be reset, but the new configuration will be saved as an attribute of the class. Default is True.
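A short sketch of reconfiguring an existing instance with these two methods (the specific values are illustrative, and whether saved configuration keys carry over between implementations follows the wrapper's own defaults):
>>> from dtw_loss_functions import soft_dtw
>>> sdtw_loss = soft_dtw.soft_dtw(implementation = 'mag', sdtw_config = {'gamma' : 0.1})
>>> # Switch backend and update the configuration after construction
>>> sdtw_loss.set_implementation('ron')
>>> sdtw_loss.set_sdtw_config({'gamma' : 0.5, 'dist' : 'sqeuclidean'})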