Coverage for contextualized/regression/regularizers.py: 89%

19 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2Torch regularizers used for regression. 

3""" 

4 

5import torch 

6from functools import partial 

7 

8 

9def no_reg_fn(beta, mu): 

10 return 0.0 

11 

12 

13def no_reg(): 

14 """ 

15 Function that returns an empty regularizer. 

16 """ 

17 return no_reg_fn 

18 

19 

20def l1_reg_fn(alpha, mu_ratio, beta, mu): 

21 return ( 

22 alpha 

23 * ( 

24 mu_ratio * torch.norm(mu, p=1) + (1 - mu_ratio) * torch.norm(beta, p=1) 

25 ).mean() 

26 ) 

27 

28 

29def l1_reg(alpha, mu_ratio=0.5): 

30 """ 

31 

32 :param alpha: 

33 :param mu_ratio: (Default value = 0.5) 

34 

35 """ 

36 return partial(l1_reg_fn, alpha, mu_ratio) 

37 

38 

39def l2_reg_fn(alpha, mu_ratio, beta, mu): 

40 return ( 

41 alpha 

42 * ( 

43 mu_ratio * torch.norm(mu, p=2) + (1 - mu_ratio) * torch.norm(beta, p=2) 

44 ).mean() 

45 ) 

46 

47 

48def l2_reg(alpha, mu_ratio=0.5): 

49 """ 

50 

51 :param alpha: 

52 :param mu_ratio: (Default value = 0.5) 

53 

54 """ 

55 return partial(l2_reg_fn, alpha, mu_ratio) 

56 

57 

58def l1_l2_reg_fn(alpha, l1_ratio, mu_ratio, beta, mu): 

59 return ( 

60 alpha 

61 * ( 

62 l1_ratio 

63 * (mu_ratio * torch.norm(mu, p=1) + (1 - mu_ratio) * torch.norm(beta, p=1)) 

64 + (1 - l1_ratio) 

65 * (mu_ratio * torch.norm(mu, p=2) + (1 - mu_ratio) * torch.norm(beta, p=2)) 

66 ).mean() 

67 ) 

68 

69 

70def l1_l2_reg(alpha, l1_ratio=0.5, mu_ratio=0.5): 

71 """ 

72 

73 :param alpha: 

74 :param l1_ratio: (Default value = 0.5) 

75 :param mu_ratio: (Default value = 0.5) 

76 

77 """ 

78 return partial(l1_l2_reg_fn, alpha, l1_ratio, mu_ratio) 

79 

80 

81REGULARIZERS = {"none": no_reg(), "l1": l1_reg, "l2": l2_reg, "l1_l2": l1_l2_reg}