"""
Implementation of OTW, as presented in :cite:`otw_paper`.
Authors
-------
Alberto Zancanaro <alberto.zancanaro@uni.lu>
"""
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Imports
import torch
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class otw(torch.nn.Module):

    def __init__(self, m: float = 1, s: int | float = 0.5, beta: float = 1, reduction: str = 'mean') -> None:
        """
        Initializes the OTW distance module.

        Parameters
        ----------
        m : float
            Waste cost parameter, default is ``1``.
        s : int | float
            Window size parameter. If a float between ``0`` and ``1``, it is interpreted as a fraction of the length of the time series; if an integer, as a number of time steps. Default is ``0.5``.
        beta : float
            Hyperparameter for the smooth l1 loss, default is ``1``.
        reduction : str
            Specifies the reduction to apply to the output: ``none`` | ``mean`` | ``sum``. Default: ``mean``.
        """
        super().__init__()
        self.m = m
        self.s = s
        self.beta = beta
        self.reduction = reduction
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Computes the OTW distance between two time series.

        Parameters
        ----------
        x : torch.Tensor
            First time series, of shape ``B x L`` where ``B`` is the batch size and ``L`` is the length of the time series.
        y : torch.Tensor
            Second time series, of shape ``B x L`` where ``B`` is the batch size and ``L`` is the length of the time series.

        Returns
        -------
        torch.Tensor
            OTW distance between the two time series.
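
        Examples
        --------
        A minimal usage sketch; shapes and hyperparameters below are illustrative only.

        >>> criterion = otw(m=1, s=0.5, beta=1, reduction='mean')
        >>> x = torch.randn(8, 50)
        >>> y = torch.randn(8, 50)
        >>> criterion(x, y).shape
        torch.Size([])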
"""
return otw_distance(x, y, self.m, self.s, self.beta, self.reduction)
def otw_distance(x: torch.Tensor, y: torch.Tensor, m: float = 1, s: int | float = 0.5, beta: float = 1, reduction: str = 'mean') -> torch.Tensor:
    """
    Implements the OTW distance between two time series, as defined in equation (9) of the paper.

    Parameters
    ----------
    x : torch.Tensor
        First time series, of shape ``B x L`` where ``B`` is the batch size and ``L`` is the length of the time series.
    y : torch.Tensor
        Second time series, of shape ``B x L`` where ``B`` is the batch size and ``L`` is the length of the time series.
    m : float
        Waste cost parameter, default is ``1``.
    s : int | float
        Window size parameter, it can be an integer or a float between ``0`` and ``1``. Default is ``0.5``.
        If a float, it is interpreted as a fraction of the length of the time series.
        If an integer, it is interpreted as the number of time steps.
    beta : float
        Hyperparameter for the smooth l1 loss, default is ``1``.
    reduction : str
        Specifies the reduction to apply to the output: ``none`` | ``mean`` | ``sum``. Default: ``mean``.

    Returns
    -------
    torch.Tensor
        OTW distance between the two time series. A scalar if ``reduction`` is ``mean`` or ``sum``, a tensor of shape ``B`` if ``reduction`` is ``none``.
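
    Examples
    --------
    A minimal usage sketch; shapes and hyperparameters below are illustrative only.

    >>> x = torch.randn(4, 100)
    >>> y = torch.randn(4, 100)
    >>> otw_distance(x, y, m=1, s=0.5, beta=1, reduction='none').shape
    torch.Size([4])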
"""
# Set temporary to high value to stabilize training
# beta = 100
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Check inputs
if x.dim() != y.dim() :
raise ValueError(f"Input time series must have the same number of dimensions. Current dimensions: x {x.shape}, y {y.shape}")
if x.dim() > 2 or y.dim() > 2 :
raise ValueError(f"Input time series must be 2-dimensional (B, L). Current dimensions: x {x.shape}, y {y.shape}")
elif x.dim() == 1 and y.dim() == 1 :
# Handle the case of single time series (no batch dimension)
x = x.unsqueeze(0)
y = y.unsqueeze(0)
if s <= 0 :
raise ValueError(f"Window size parameter s must be positive. Current value: {s}")
if 0 < s < 1 :
s = int(s * x.size(1))
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Compute OTW distance
otw_term_1 = m * smooth_l1_loss(window_cumsum(x - y, s), beta, reduction = reduction)
otw_term_2 = 0
for i in range(x.shape[1] - 1) :
otw_term_2 += smooth_l1_loss(window_cumsum(x[:, 0:(i + 1)] - y[:, 0:(i + 1)], s), beta, reduction = reduction)
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
return otw_term_1 + otw_term_2
def smooth_l1_loss(x: torch.Tensor, beta: float, reduction: str = 'mean') -> torch.Tensor:
    """
    Computes the smooth l1 of the input tensor x, as defined in equation (9) of the paper.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``B``. Each element corresponds to the difference between two time series.
    beta : float
        Hyperparameter for the smooth l1 loss.
    reduction : str
        Specifies the reduction to apply to the output: ``none`` | ``mean`` | ``sum``. Default: ``mean``.

    Returns
    -------
    torch.Tensor
        Smooth l1 loss of the input tensor. If reduction is ``none``, the output has the same shape as `x`. If reduction is ``mean`` or ``sum``, the output is a scalar.
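
    Examples
    --------
    A worked example; values are illustrative only. ``0.5`` falls in the quadratic branch (``0.5 * 0.5**2 / 1 = 0.125``) and ``2.0`` in the linear branch (``2.0 - 0.5 * 1 = 1.5``).

    >>> smooth_l1_loss(torch.tensor([0.5, 2.0]), beta=1, reduction='none')
    tensor([0.1250, 1.5000])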
"""
# Compute smooth l1 loss element-wise
loss = torch.where(torch.abs(x) < beta, 0.5 * x ** 2 / beta, torch.abs(x) - 0.5 * beta)
# Apply reduction
if reduction == 'mean' :
loss = loss.mean()
elif reduction == 'sum' :
loss = loss.sum()
return loss
def window_cumsum(x: torch.Tensor, s: int) -> torch.Tensor:
    """
    Computes the windowed cumulative sum of the input tensor x, as defined in equation (7) of the paper.

    Given a time series A represented as an array of values ``[a1, a2, ..., aL]``, the windowed cumsum is computed as
    ``window_cumsum(A) = sum(A) - sum(A[0:L-s])``,
    i.e. the sum of the whole array minus the sum of the array excluding the last ``s`` elements, which equals the sum of the last ``s`` elements.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape ``(B, L)``.
    s : int
        Window size.

    Returns
    -------
    torch.Tensor
        Sum of the last ``s`` elements of each series, of shape ``(B,)``.
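
    Examples
    --------
    A worked example; values are illustrative only. With ``s = 2``, the result is the sum of the last two elements, ``3 + 4 = 7``.

    >>> window_cumsum(torch.tensor([[1., 2., 3., 4.]]), s=2)
    tensor([7.])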
"""
cumsum_x = torch.cumsum(x, dim = 1)
windowed_cumsum = cumsum_x[:, -1] - (cumsum_x[:, - (s + 1)] if s < x.shape[1] else 0)
return windowed_cumsum
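# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Minimal smoke test (a sketch only; batch size, length and hyperparameters are illustrative)
if __name__ == "__main__":
    x = torch.randn(4, 100)
    y = torch.randn(4, 100)
    criterion = otw(m=1, s=0.5, beta=1, reduction='mean')
    print(criterion(x, y))                              # Scalar OTW distance averaged over the batch
    print(otw_distance(x, y, reduction='none').shape)   # Per-sample distances, shape (4,)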