Block DTW
Implementation of the block DTW, a variant of the SDTW that computes the SDTW loss on blocks of the signal instead of on the entire signal.
Example
>>> from dtw_loss_functions import block_dtw
>>> import torch
>>> block_size = 25
>>> use_cuda = torch.cuda.is_available()
>>> block_dtw_loss = block_dtw.block_dtw(block_size, sdtw_config={'use_cuda': use_cuda})
>>> batch_size = 5
>>> time_samples = 300
>>> channels = 1
>>> device = 'cuda' if use_cuda else 'cpu'
>>> x = torch.randn(batch_size, time_samples, channels).to(device)
>>> x_r = torch.randn(batch_size, time_samples, channels).to(device)
>>> output_block_dtw = block_dtw_loss(x, x_r)
- class dtw_loss_functions.block_dtw.block_dtw(block_size: int, implementation: str = 'mag', sdtw_config: dict = {})[source]
Bases: Module
Class that computes the block DTW loss, which is a variant of the SDTW that computes the SDTW on blocks of the signal instead of on the entire signal. The block DTW can be computed in two ways:
Naive (SEQUENTIAL) implementation: compute the SDTW on each block separately and sum the results.
Optimized (PARALLEL) implementation: exploit reshaping of the input tensors to compute the SDTW on all blocks at once.
This class automatically selects which implementation to use based on the length of the input tensors and the block size; the selection criterion is sketched below. See the docstring of block_dtw_optimized for more details on the requirements of the optimized implementation.
If you are not sure which implementation applies, use the block_dtw class. If you know a priori that the optimized version can be used in your case, it is recommended to use the block_dtw_optimized class directly, since it skips the overhead of checking the input tensor length against the block size.
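A minimal sketch of this selection criterion (the helper below is illustrative only and not part of the package; it simply encodes the divisibility requirement documented for block_dtw_optimized):

def pick_implementation(time_samples: int, block_size: int) -> str:
    # Hypothetical helper: the optimized (parallel) path requires the
    # signal length to be an exact multiple of the block size;
    # otherwise only the naive (sequential) path is applicable.
    if time_samples % block_size == 0:
        return 'optimized'
    return 'naive'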
- block_size
Size of the blocks into which to divide the signal.
- Type:
int
- block_dtw_naive
Instance of the naive implementation of the block DTW.
- Type:
block_dtw_naive
- block_dtw_optimized
Instance of the optimized implementation of the block DTW.
- Type:
block_dtw_optimized
- Parameters:
block_size (int) – Size of the blocks into which to divide the signal.
implementation (str) – Implementation to use for the SDTW. This parameter is passed to the SDTW implementation used in the block DTW. See the docstring of the soft_dtw class for more details on the available implementations and their parameters.
sdtw_config (dict, optional) – Configuration dictionary for the SDTW function used in the block DTW. See the docstring of the soft_dtw class for more details on the available parameters and their default values. If a parameter is not specified, or if the dictionary is empty, the default values are used for all parameters.
Methods
forward(x, x_r) – Compute the block DTW loss between the input tensors x and x_r.
- forward(x: tensor, x_r: tensor) → tensor[source]
Compute the block DTW loss between the input tensors x and x_r.
- Parameters:
x (torch.tensor) – First input tensor of shape B x T x C
x_r (torch.tensor) – Second input tensor of shape B x T x C
- Returns:
recon_error – Tensor of shape B containing the block DTW loss for each sample in the batch.
- Return type:
torch.tensor
- class dtw_loss_functions.block_dtw.block_dtw_naive(block_size: int, implementation: str = 'mag', sdtw_config: dict = {})[source]
Bases: soft_dtw
Naive implementation of the block DTW, which computes the SDTW on each block separately.
For details on the parameters, see the docstring of the block_dtw class.
Methods
forward(x, x_r) – Compute the block DTW loss between the input tensors x and x_r by computing the SDTW on each block separately and summing the results.
- forward(x: tensor, x_r: tensor) → tensor[source]
Compute the block DTW loss between the input tensors x and x_r by computing the SDTW on each block separately and summing the results.
- Parameters:
x (torch.tensor) – First input tensor of shape B x T x C
x_r (torch.tensor) – Second input tensor of shape B x T x C
- Returns:
recon_error – Tensor of shape B containing the block DTW loss for each sample in the batch.
- Return type:
torch.tensor
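A minimal sketch of this sequential strategy (illustrative only; sdtw stands in for any batched SDTW callable mapping a pair of B x L x C tensors to B losses, and the names are not the package's internals):

import torch

def block_dtw_naive_sketch(x, x_r, block_size, sdtw):
    B, T, _ = x.shape
    recon_error = torch.zeros(B, device=x.device)
    # Slice the time axis into consecutive blocks and accumulate the
    # SDTW loss of each pair of blocks.
    for start in range(0, T, block_size):
        stop = min(start + block_size, T)  # the last block may be shorter
        recon_error = recon_error + sdtw(x[:, start:stop], x_r[:, start:stop])
    return recon_error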
- class dtw_loss_functions.block_dtw.block_dtw_optimized(block_size: int, implementation: str = 'mag', sdtw_config: dict = {})[source]
Bases: soft_dtw
Optimized implementation of the block DTW, which exploits reshaping of the input tensors to compute the SDTW on all blocks at once.
This version can be used only if the length of the input tensors is divisible by the block size, i.e. if length_signal % block_size == 0. Note that the class does not check whether this condition is satisfied, so it is the responsibility of the user to ensure that the input tensor length and the block size are compatible.
This requirement exists because the optimized implementation exploits reshaping of the input tensors to compute the SDTW on all blocks at once. This works because the SoftDTW implementation allows batched inputs, so the input tensors can be reshaped so that every block becomes an element of a larger batch, and the SDTW can be computed on all blocks in a single call, as sketched below.
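A minimal sketch of the reshaping idea (illustrative only; sdtw stands in for any batched SDTW callable mapping a pair of N x L x C tensors to N losses):

def block_dtw_optimized_sketch(x, x_r, block_size, sdtw):
    B, T, C = x.shape
    num_blocks = T // block_size  # requires T % block_size == 0
    # (B, T, C) -> (B * num_blocks, block_size, C): every block becomes
    # an independent element of a larger batch.
    x_blocks = x.reshape(B * num_blocks, block_size, C)
    x_r_blocks = x_r.reshape(B * num_blocks, block_size, C)
    per_block = sdtw(x_blocks, x_r_blocks)  # shape (B * num_blocks,)
    # Sum the per-block losses that belong to the same original sample.
    return per_block.reshape(B, num_blocks).sum(dim=1)  # shape (B,)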
For details on the parameters, see the docstring of the block_dtw class.
Methods
forward(x, x_r) – Compute the block DTW loss between the input tensors x and x_r by exploiting reshaping of the input tensors to compute the SDTW on all blocks at once.
- forward(x: tensor, x_r: tensor) → tensor[source]
Compute the block DTW loss between the input tensors x and x_r by exploiting reshaping of the input tensors to compute the SDTW on all blocks at once.
- Parameters:
x (torch.tensor) – First input tensor of shape B x T x C
x_r (torch.tensor) – Second input tensor of shape B x T x C
- Returns:
recon_error – Tensor of shape B containing the block DTW loss for each sample in the batch.
- Return type:
torch.tensor