Coverage for contextualized/regression/losses.py: 100%

7 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:38 -0400

1""" 

2Losses used in regression. 

3""" 

4 

5import torch 

6 

7 

def MSE(Y_true, Y_pred):
    """Mean squared error (L2 loss) between targets and predictions.

    Default loss for contextualized.regression models, shared by the
    multivariate and univariate regression problems.

    Args:
        Y_true (torch.Tensor): target values.
        Y_pred (torch.Tensor): predicted values; subtracted elementwise from
            ``Y_true`` (standard torch broadcasting applies).

    Returns:
        torch.Tensor: scalar tensor, the mean of the squared residuals taken
        over every element.

    Shape conventions (MV/UV = multivariate/univariate,
    MT/ST = multi-task/single-task):

        MV ST: beta (y_dim, x_dim), mu (y_dim, 1), x (y_dim, x_dim), y (y_dim, 1)
        MV MT: beta (x_dim,), mu (1,), x (x_dim,), y (1,)
        UV ST: beta (y_dim, x_dim, 1), mu (y_dim, x_dim, 1), x (y_dim, x_dim, 1), y (y_dim, x_dim, 1)
        UV MT: beta (1,), mu (1,), x (1,), y (1,)
    """
    squared_error = (Y_true - Y_pred) ** 2
    return squared_error.mean()

25 

26 

def BCELoss(Y_true, Y_pred, *, eps=1e-8):
    """Mean binary cross-entropy between targets and predicted probabilities.

    Computes ``-(y * log(p + eps) + (1 - y) * log(1 - p + eps))`` elementwise
    and averages over every element.

    Args:
        Y_true (torch.Tensor): target values, typically 0/1 labels.
        Y_pred (torch.Tensor): predicted probabilities. Both log arguments stay
            positive only while ``-eps < Y_pred < 1 + eps``; values outside
            that range produce NaN.
        eps (float, keyword-only): small constant added inside each ``log`` so
            that saturated predictions (exactly 0 or 1) give a large finite
            penalty instead of ``log(0) = -inf`` (and ``0 * -inf = NaN`` for the
            inactive term). Defaults to ``1e-8``.

    Returns:
        torch.Tensor: scalar tensor with the mean loss over all elements.
    """
    positive_term = Y_true * torch.log(Y_pred + eps)
    negative_term = (1 - Y_true) * torch.log(1 - Y_pred + eps)
    loss = -(positive_term + negative_term)
    return loss.mean()