Coverage for contextualized/dags/losses.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:38 -0400

1import numpy as np 

2import torch 

3from contextualized.dags.graph_utils import dag_pred, dag_pred_with_factors 

4 

5 

def dag_loss_dagma_indiv(w, s=1):
    """Return the DAGMA log-determinant acyclicity penalty for ``w``.

    Computes ``d * log(s) - log|det(s * I - w * w)|`` where ``w * w`` is the
    element-wise square and ``d = w.shape[-1]``.

    Args:
        w (torch.Tensor): Square matrix of shape ``(d, d)``. Leading batch
            dimensions are also accepted, since every operation used here
            broadcasts over them.
        s (float): Scale applied to the identity matrix. Defaults to 1.

    Returns:
        torch.Tensor: The penalty; a 0-d tensor for a single matrix, or one
        value per leading batch element.
    """
    d = w.shape[-1]
    # Build the identity on w's device (and dtype, when floating point) so
    # inputs that live on an accelerator do not hit a device mismatch.
    dtype = w.dtype if torch.is_floating_point(w) else None
    identity = torch.eye(d, device=w.device, dtype=dtype)
    m = s * identity - w * w
    log_abs_det = torch.linalg.slogdet(m)[1]
    # Collapse the NumPy scalar to a plain float so torch owns the
    # subtraction and the result keeps the tensor's dtype.
    return float(d * np.log(s)) - log_abs_det


def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):
    """DAG loss on batched networks W using the DAGMA log-determinant.

    Args:
        W (torch.Tensor): Batch of square matrices of shape ``(n, d, d)``.
            A sequence of ``(d, d)`` tensors is stacked into a batch first.
        s (float): Scale applied to the identity matrix. Defaults to 1.
        alpha (float): Weight multiplying the mean penalty. Defaults to 0.0,
            which makes the returned loss zero.
        **kwargs: Unused.

    Returns:
        torch.Tensor: Scalar ``alpha * mean(per-sample penalty)`` that stays
        connected to ``W`` in the autograd graph.
    """
    if not torch.is_tensor(W):
        W = torch.stack(tuple(W))
    # Evaluate the whole batch in one differentiable call. Wrapping the
    # per-sample results in torch.Tensor([...]) would copy their values into
    # a fresh tensor and cut the gradient path back to W.
    sample_losses = dag_loss_dagma_indiv(W, s)
    return alpha * torch.mean(sample_losses)

17 

18 

def dag_loss_poly_indiv(w):
    """Return the polynomial acyclicity penalty for ``w``.

    Computes ``h_poly(w) = Tr((I + (1/d) * (w * w))^d) - d`` where ``w * w``
    is the element-wise square, ``^d`` is the d-th matrix power and
    ``d = w.shape[-1]``. Because ``w * w`` is non-negative, the value is zero
    when the weighted graph has no cycles and positive otherwise.

    Args:
        w (torch.Tensor): Square matrix of shape ``(d, d)``. Leading batch
            dimensions are also accepted.

    Returns:
        torch.Tensor: The penalty; a 0-d tensor for a single matrix, or one
        value per leading batch element.
    """
    d = w.shape[-1]
    # Build the identity on w's device (and dtype, when floating point) so
    # inputs that live on an accelerator do not hit a device mismatch.
    dtype = w.dtype if torch.is_floating_point(w) else None
    identity = torch.eye(d, device=w.device, dtype=dtype)
    # The element-wise square and the d-th matrix power are both required:
    # a plain w @ w only sees 2-cycles and can make the penalty negative.
    powered = torch.linalg.matrix_power(identity + (w * w) / d, d)
    return powered.diagonal(dim1=-2, dim2=-1).sum(-1) - d


def dag_loss_poly(W, **kwargs):
    """DAG loss on batched networks W using the h_poly form.

    ``h_poly(W) = Tr((I + (1/d) * (W * W))^d) - d``, averaged over the batch.

    Args:
        W (torch.Tensor): Batch of square matrices of shape ``(n, d, d)``.
            A sequence of ``(d, d)`` tensors is stacked into a batch first.
        **kwargs: Unused.

    Returns:
        torch.Tensor: Scalar mean penalty that stays connected to ``W`` in
        the autograd graph.
    """
    if not torch.is_tensor(W):
        W = torch.stack(tuple(W))
    # Evaluate the whole batch in one differentiable call. Wrapping the
    # per-sample results in torch.Tensor([...]) would copy their values into
    # a fresh tensor and cut the gradient path back to W.
    return torch.mean(dag_loss_poly_indiv(W))

29 

30 

def dag_loss_notears(W, alpha=0.0, rho=0.0, **kwargs):
    """DAG loss on batched networks W using the NOTEARS matrix-exponential trace.

    For each matrix, ``h = Tr(exp(W * W)) - d`` with ``W * W`` the
    element-wise square and ``d = W.shape[-1]``. The returned value is the
    mean of ``alpha * h + 0.5 * rho * h * h`` over all matrices.

    Args:
        W (torch.Tensor): Square matrices in the last two dimensions, with
            any leading batch dimensions.
        alpha (float): Weight of the linear term. Defaults to 0.0.
        rho (float): Weight of the quadratic term. Defaults to 0.0.
        **kwargs: Unused.

    Returns:
        torch.Tensor: Scalar mean penalty.
    """
    n_nodes = W.shape[-1]
    exp_of_square = torch.linalg.matrix_exp(W * W)
    # Trace taken over the last two dimensions, one value per matrix.
    trace = torch.diagonal(exp_of_square, offset=0, dim1=-2, dim2=-1).sum(-1)
    h = trace - n_nodes
    penalty = alpha * h + 0.5 * rho * h * h
    return torch.mean(penalty)

39 

40 

def l1_loss(w, l1):
    """Lasso (L1) regularization term.

    Args:
        w (torch.Tensor): Values to penalize.
        l1 (float): Regularization strength multiplying the penalty.

    Returns:
        torch.Tensor: Scalar ``l1 * sum(|w|)``.
    """
    return l1 * torch.sum(torch.abs(w))

43 

44 

def mse_loss(y_true, y_pred):
    """Mean squared error of y_true vs. y_pred.

    Args:
        y_true: Ground-truth values; must support subtraction with
            ``y_pred`` and expose ``.mean()`` on the squared difference.
        y_pred: Predicted values.

    Returns:
        The mean of ``(y_true - y_pred) ** 2`` over all elements.
    """
    return ((y_true - y_pred) ** 2).mean()

47 

48 

def linear_sem_loss_with_factors(x_true, w_pred, factor_mat):
    """Reconstruction loss of a factorized linear structural equation model.

    The reconstruction of ``x_true`` is obtained from
    ``dag_pred_with_factors``; the loss is half the squared 2-norm of the
    residual (taken over all elements), divided by ``x_true.shape[0]``.
    Works on batches only.

    Args:
        x_true (torch.FloatTensor): Vector of true features x.
        w_pred (torch.FloatTensor): Predicted linear structural equation model.
        factor_mat (torch.FloatTensor): Factor matrix, size
            latent factors x x.shape[-1].

    Returns:
        torch.tensor: Scalar loss for data features and predicted SEM.
    """
    reconstruction = dag_pred_with_factors(x_true, w_pred, factor_mat)
    residual_norm = torch.linalg.norm(x_true - reconstruction)
    # Normalize by the size of the first (batch) dimension.
    scale = 0.5 / x_true.shape[0]
    return scale * torch.square(residual_norm)

63 

64 

def linear_sem_loss(x_true, w_pred):
    """Reconstruction loss of a linear structural equation model.

    The reconstruction of ``x_true`` is obtained from ``dag_pred``; the loss
    is half the squared 2-norm of the residual (taken over all elements),
    divided by ``x_true.shape[0]``. Works on batches only.

    Args:
        x_true (torch.FloatTensor): Vector of true features x.
        w_pred (torch.FloatTensor): Predicted linear structural equation model.

    Returns:
        torch.tensor: Scalar loss for data features and predicted SEM.
    """
    reconstruction = dag_pred(x_true, w_pred)
    residual_norm = torch.linalg.norm(x_true - reconstruction)
    # Normalize by the size of the first (batch) dimension.
    scale = 0.5 / x_true.shape[0]
    return scale * torch.square(residual_norm)