Coverage for contextualized/dags/losses.py: 100%
26 statements
coverage.py v7.4.4, created at 2024-04-21 13:38 -0400
import numpy as np
import torch
from contextualized.dags.graph_utils import dag_pred, dag_pred_with_factors

def dag_loss_dagma_indiv(w, s=1):
    """DAGMA acyclicity penalty for a single network w: -logdet(s*I - w*w) + d*log(s)."""
    M = s * torch.eye(w.shape[-1]) - w * w
    return w.shape[-1] * np.log(s) - torch.slogdet(M)[1]

def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):
    """DAG loss on batched networks W using the DAGMA log-determinant,
    averaged over the batch and scaled by alpha.
    """
    sample_losses = torch.Tensor([dag_loss_dagma_indiv(w, s) for w in W])
    return alpha * torch.mean(sample_losses)

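# Illustrative usage sketch (shapes and values below are assumptions, not part of the
# module): the DAGMA penalty is ~0 for any acyclic network and grows once a cycle
# appears, so the empty graph contributes nothing while a 2-cycle contributes a
# positive term.
_dagma_W = torch.zeros(2, 3, 3)                      # batch of two 3-node graphs
_dagma_W[1, 0, 1] = _dagma_W[1, 1, 0] = 0.8          # 2-cycle in the second graph
dag_loss_dagma(_dagma_W, s=1, alpha=1.0)             # mean of [~0, >0], a small positive value
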
def dag_loss_poly_indiv(w):
    """Polynomial acyclicity penalty for a single network w."""
    d = w.shape[-1]
    return torch.trace(torch.eye(d) + (1 / d) * torch.matmul(w, w)) - d

def dag_loss_poly(W, **kwargs):
    """DAG loss on batched networks W using the polynomial form
    h_poly(W) = Tr(I + (1/d) * W @ W) - d, averaged over the batch.
    """
    return torch.mean(torch.Tensor([dag_loss_poly_indiv(w) for w in W]))

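# A minimal sketch (assumed shapes, for illustration only): for the empty graph the
# trace of I is exactly d, so the penalty is 0; a 2-cycle puts positive mass on the
# diagonal of W @ W and pushes the penalty above zero.
_poly_W = torch.zeros(2, 3, 3)
_poly_W[1, 0, 1] = _poly_W[1, 1, 0] = 0.8            # 2-cycle between nodes 0 and 1
dag_loss_poly(_poly_W)                               # mean penalty, roughly 0.21 here
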
def dag_loss_notears(W, alpha=0.0, rho=0.0, **kwargs):
    """DAG loss on batched networks W using the NOTEARS matrix-exponential trace
    h(W) = Tr(exp(W*W)) - d, combined as alpha * h + 0.5 * rho * h^2 and averaged
    over the batch.
    """
    m = torch.linalg.matrix_exp(W * W)
    h = m.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1) - W.shape[-1]
    return torch.mean(alpha * h + 0.5 * rho * h * h)

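# Illustrative sketch (example shapes and coefficients are assumptions): for an acyclic
# W the matrix exponential of W*W has trace d, so h = 0; a cycle makes h > 0, and the
# returned loss mixes the linear (alpha) and quadratic (rho) augmented-Lagrangian terms.
_nt_W = torch.zeros(2, 3, 3)
_nt_W[0, 0, 1] = 0.9                                 # acyclic sample: h = 0
_nt_W[1, 0, 1] = _nt_W[1, 1, 0] = 0.8                # 2-cycle sample: h > 0
dag_loss_notears(_nt_W, alpha=1.0, rho=1.0)          # small positive mean
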
# Lasso (L1) regularization term
l1_loss = lambda w, l1: l1 * torch.sum(torch.abs(w))

# Mean squared error of y_true vs. y_pred
mse_loss = lambda y_true, y_pred: ((y_true - y_pred) ** 2).mean()

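# Quick sketch of the two helpers above (example values only, not part of the module):
# l1_loss scales the summed absolute edge weights by the l1 strength, and mse_loss
# averages the squared error over every element of its inputs.
_w_demo = torch.tensor([[0.0, 0.5], [-0.25, 0.0]])
l1_loss(_w_demo, l1=0.1)                             # 0.1 * (0.5 + 0.25) = 0.075
mse_loss(torch.zeros(4), torch.ones(4))              # 1.0
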
def linear_sem_loss_with_factors(x_true, w_pred, factor_mat):
    """Computes MSE loss between the true x and the predicted linear structural
    equation model, for torch tensors. Works on batches only.

    Args:
        x_true (torch.FloatTensor): True data features x
        w_pred (torch.FloatTensor): Predicted linear structural equation model
        factor_mat (torch.FloatTensor): Factor matrix of shape (n_latent_factors, x.shape[-1])

    Returns:
        torch.tensor: MSE loss between the data features and the predicted SEM.
    """
    x_prime = dag_pred_with_factors(x_true, w_pred, factor_mat)
    return (0.5 / x_true.shape[0]) * torch.square(torch.linalg.norm(x_true - x_prime))

def linear_sem_loss(x_true, w_pred):
    """Computes MSE loss between the true x and the predicted linear structural
    equation model, for torch tensors. Works on batches only.

    Args:
        x_true (torch.FloatTensor): True data features x
        w_pred (torch.FloatTensor): Predicted linear structural equation model

    Returns:
        torch.tensor: MSE loss between the data features and the predicted SEM.
    """
    x_prime = dag_pred(x_true, w_pred)
    return (0.5 / x_true.shape[0]) * torch.square(torch.linalg.norm(x_true - x_prime))

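# A minimal sketch, assuming x_true has shape (n_samples, n_features), w_pred holds one
# (d, d) adjacency matrix per sample, and dag_pred reconstructs each row as
# x_hat[i] = x[i] @ w[i]; the loss is then 0.5/n times the squared Frobenius norm of the
# residual, as in the body above. These shape and behavior assumptions are illustrative.
_x = torch.randn(8, 3)
_w = torch.zeros(8, 3, 3)                            # empty graphs reconstruct x_hat = 0
linear_sem_loss(_x, _w)                              # equals (0.5 / 8) * ||x||_F^2 under the assumptions above
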