Coverage for contextualized/regression/regularizers.py: 89%
19 statements
Generated by coverage.py v7.4.4 on 2024-04-21 13:49 -0400.
1"""
2Torch regularizers used for regression.
3"""
5import torch
6from functools import partial
def no_reg_fn(beta, mu):
    """
    Regularizer that applies no penalty.

    Both arguments are accepted only so this function has the same call
    signature as the other regularizers; they are ignored.

    :param beta: ignored.
    :param mu: ignored.
    :returns: the float ``0.0``.
    """
    zero_penalty = 0.0
    return zero_penalty
def no_reg():
    """
    Return the empty (no-op) regularizer.

    Unlike ``l1_reg``, ``l2_reg`` and ``l1_l2_reg``, this factory takes no
    hyperparameters and returns ``no_reg_fn`` itself rather than a partial.

    :returns: ``no_reg_fn``, a callable ``(beta, mu) -> 0.0``.
    """
    regularizer = no_reg_fn
    return regularizer
def l1_reg_fn(alpha, mu_ratio, beta, mu):
    """
    Weighted L1 penalty on ``mu`` and ``beta``.

    Computes ``alpha * (mu_ratio * ||mu||_1 + (1 - mu_ratio) * ||beta||_1)``.
    Each norm is taken over *all* entries of its tensor (the tensor is
    flattened), so a single 0-dim tensor is returned, not a per-sample value.

    :param alpha: overall regularization strength.
    :param mu_ratio: weight on the ``mu`` norm; the ``beta`` norm gets
        ``1 - mu_ratio``.
    :param beta: tensor whose L1 norm is weighted by ``1 - mu_ratio``.
    :param mu: tensor whose L1 norm is weighted by ``mu_ratio``.
    :returns: 0-dim tensor holding the penalty.
    """
    # torch.norm is deprecated; torch.linalg.vector_norm with the default
    # dim=None flattens its input, matching torch.norm(x, p=1). The former
    # trailing .mean() was applied to a 0-dim tensor and had no effect.
    mu_term = torch.linalg.vector_norm(mu, ord=1)
    beta_term = torch.linalg.vector_norm(beta, ord=1)
    return alpha * (mu_ratio * mu_term + (1 - mu_ratio) * beta_term)
def l1_reg(alpha, mu_ratio=0.5):
    """
    Build an L1 regularizer with fixed hyperparameters.

    :param alpha: overall regularization strength.
    :param mu_ratio: weight given to the ``mu`` norm; the ``beta`` norm gets
        ``1 - mu_ratio`` (Default value = 0.5).
    :returns: a ``functools.partial`` of ``l1_reg_fn`` to be called as
        ``fn(beta, mu)``.
    """
    # Bind positionally so the remaining positional parameters of
    # l1_reg_fn are exactly (beta, mu).
    bound = (alpha, mu_ratio)
    return partial(l1_reg_fn, *bound)
def l2_reg_fn(alpha, mu_ratio, beta, mu):
    """
    Weighted L2 penalty on ``mu`` and ``beta``.

    Computes ``alpha * (mu_ratio * ||mu||_2 + (1 - mu_ratio) * ||beta||_2)``.
    These are plain (not squared) Euclidean norms, each taken over *all*
    entries of its tensor (the tensor is flattened), so a single 0-dim
    tensor is returned, not a per-sample value.

    :param alpha: overall regularization strength.
    :param mu_ratio: weight on the ``mu`` norm; the ``beta`` norm gets
        ``1 - mu_ratio``.
    :param beta: tensor whose L2 norm is weighted by ``1 - mu_ratio``.
    :param mu: tensor whose L2 norm is weighted by ``mu_ratio``.
    :returns: 0-dim tensor holding the penalty.
    """
    # torch.norm is deprecated; torch.linalg.vector_norm with the default
    # dim=None flattens its input, matching torch.norm(x, p=2). The former
    # trailing .mean() was applied to a 0-dim tensor and had no effect.
    mu_term = torch.linalg.vector_norm(mu, ord=2)
    beta_term = torch.linalg.vector_norm(beta, ord=2)
    return alpha * (mu_ratio * mu_term + (1 - mu_ratio) * beta_term)
def l2_reg(alpha, mu_ratio=0.5):
    """
    Build an L2 regularizer with fixed hyperparameters.

    :param alpha: overall regularization strength.
    :param mu_ratio: weight given to the ``mu`` norm; the ``beta`` norm gets
        ``1 - mu_ratio`` (Default value = 0.5).
    :returns: a ``functools.partial`` of ``l2_reg_fn`` to be called as
        ``fn(beta, mu)``.
    """
    # Bind positionally so the remaining positional parameters of
    # l2_reg_fn are exactly (beta, mu).
    bound = (alpha, mu_ratio)
    return partial(l2_reg_fn, *bound)
def l1_l2_reg_fn(alpha, l1_ratio, mu_ratio, beta, mu):
    """
    Elastic-net style mix of the weighted L1 and L2 penalties.

    Computes::

        alpha * (
            l1_ratio       * (mu_ratio * ||mu||_1 + (1 - mu_ratio) * ||beta||_1)
          + (1 - l1_ratio) * (mu_ratio * ||mu||_2 + (1 - mu_ratio) * ||beta||_2)
        )

    The L2 norms are plain (not squared). Every norm is taken over *all*
    entries of its tensor (the tensor is flattened), so a single 0-dim
    tensor is returned, not a per-sample value.

    :param alpha: overall regularization strength.
    :param l1_ratio: weight on the L1 part; the L2 part gets ``1 - l1_ratio``.
    :param mu_ratio: weight on the ``mu`` norms; the ``beta`` norms get
        ``1 - mu_ratio``.
    :param beta: tensor weighted by ``1 - mu_ratio`` in both parts.
    :param mu: tensor weighted by ``mu_ratio`` in both parts.
    :returns: 0-dim tensor holding the penalty.
    """
    # torch.norm is deprecated; torch.linalg.vector_norm with the default
    # dim=None flattens its input, matching torch.norm(x, p=...). The former
    # trailing .mean() was applied to a 0-dim tensor and had no effect.
    l1_part = mu_ratio * torch.linalg.vector_norm(mu, ord=1) + (
        1 - mu_ratio
    ) * torch.linalg.vector_norm(beta, ord=1)
    l2_part = mu_ratio * torch.linalg.vector_norm(mu, ord=2) + (
        1 - mu_ratio
    ) * torch.linalg.vector_norm(beta, ord=2)
    return alpha * (l1_ratio * l1_part + (1 - l1_ratio) * l2_part)
def l1_l2_reg(alpha, l1_ratio=0.5, mu_ratio=0.5):
    """
    Build a mixed L1/L2 regularizer with fixed hyperparameters.

    :param alpha: overall regularization strength.
    :param l1_ratio: weight on the L1 part; the L2 part gets
        ``1 - l1_ratio`` (Default value = 0.5).
    :param mu_ratio: weight given to the ``mu`` norms; the ``beta`` norms get
        ``1 - mu_ratio`` (Default value = 0.5).
    :returns: a ``functools.partial`` of ``l1_l2_reg_fn`` to be called as
        ``fn(beta, mu)``.
    """
    # Bind positionally so the remaining positional parameters of
    # l1_l2_reg_fn are exactly (beta, mu).
    bound = (alpha, l1_ratio, mu_ratio)
    return partial(l1_l2_reg_fn, *bound)
# Regularizers looked up by name.
# NOTE: the entries are not uniform. "none" holds the regularizer itself
# (no_reg() is called here and returns no_reg_fn, which takes (beta, mu)).
# "l1", "l2" and "l1_l2" hold factories that must first be called with their
# hyperparameters (alpha, ...) to obtain a (beta, mu) callable.
REGULARIZERS = dict(
    none=no_reg(),
    l1=l1_reg,
    l2=l2_reg,
    l1_l2=l1_l2_reg,
)