Coverage for contextualized/functions.py: 81%

32 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:38 -0400

1"""Utility torch mathematical functions which are used for many modules. 

2""" 

3 

4import torch 

5import torch.nn.functional as F 

6from functools import partial 

7 

8 

def zero_vector(x, *args):
    """
    Return a column of zeros with one row per element of ``x``.

    The result is allocated on the same device as ``x`` when ``x`` exposes a
    ``device`` attribute (e.g. a torch tensor). This avoids device-mismatch
    errors when the zeros are combined with GPU tensors downstream. Plain
    sized sequences still work and fall back to the default device.

    :param x: sized input; only ``len(x)`` and (if present) ``x.device`` are used.
    :param *args: accepted and ignored (presumably so this can stand in for
        functions that receive extra positional arguments; confirm with callers).
    :returns: float tensor of shape ``(len(x), 1)`` filled with zeros.
    """
    # getattr keeps list/sequence inputs working; device=None means "default device".
    return torch.zeros((len(x), 1), device=getattr(x, "device", None))

11 

12 

def zero(x):
    """
    Return a zero tensor shaped like ``x``.

    :param x: reference tensor; its shape, dtype, device and layout are reused.
    :returns: ``torch.zeros_like(x)``.
    """
    zeros = torch.zeros_like(x)
    return zeros

15 

16 

def identity(x):
    """
    Return the argument unchanged (the very same object, not a copy).

    :param x: any value.
    :returns: ``x``.
    """
    result = x
    return result

19 

20 

def linear(x, slope, intercept):
    """
    Evaluate the affine map ``x * slope + intercept``.

    :param x: input value or tensor.
    :param slope: multiplicative factor.
    :param intercept: additive offset.
    :returns: ``x * slope + intercept``.
    """
    scaled = x * slope
    return scaled + intercept

23 

24 

def logistic(x, slope, intercept):
    """
    Logistic (sigmoid) function of the affine map ``x * slope + intercept``.

    Uses ``torch.sigmoid`` rather than the literal ``1 / (1 + exp(-z))``.
    For very negative ``z`` the literal form overflows ``exp`` to ``inf``,
    and back-propagation then computes ``0 * inf`` and yields NaN gradients.
    ``torch.sigmoid`` is numerically stable in both directions.

    :param x: input tensor.
    :param slope: multiplicative factor applied to ``x``.
    :param intercept: additive offset applied after scaling.
    :returns: tensor of values in ``[0, 1]``.
    """
    return torch.sigmoid(x * slope + intercept)

27 

28 

def linear_link(x, slope, intercept):
    """
    Linear link function: ``x * slope + intercept``.

    Computes the same value as ``linear``.

    :param x: input value or tensor.
    :param slope: multiplicative factor.
    :param intercept: additive offset.
    :returns: ``x * slope + intercept``.
    """
    product = x * slope
    return product + intercept

31 

32 

def identity_link(x):
    """
    Identity link function: return ``x`` itself, unmodified.

    :param x: any value.
    :returns: ``x`` (same object).
    """
    output = x
    return output

35 

36 

def softmax_link(x, slope, intercept, *, dim=1):
    """
    Softmax link function applied to ``x * slope + intercept``.

    :param x: input tensor; must have an axis ``dim``.
    :param slope: multiplicative factor applied to ``x``.
    :param intercept: additive offset applied after scaling.
    :param dim: axis to normalise over (keyword-only). The default ``1``
        preserves the original behaviour. Pass ``dim=0`` or ``dim=-1`` for
        other layouts, e.g. 1-D inputs, where ``dim=1`` does not exist.
    :returns: tensor of the same shape whose slices along ``dim`` sum to 1.
    """
    logits = x * slope + intercept
    return F.softmax(logits, dim=dim)

39 

40 

def make_fn(base_fn, **params):
    """
    Bind keyword arguments to ``base_fn`` and return the resulting callable.

    The result is a ``functools.partial``. The bound keywords act as defaults
    that a caller can still override per call.

    :param base_fn: callable to specialise.
    :param **params: keyword arguments to fix on ``base_fn``.
    :returns: ``functools.partial(base_fn, **params)``.
    """
    bound_fn = partial(base_fn, **params)
    return bound_fn

50 

51 

def linear_constructor(slope=1, intercept=0):
    """
    Build a one-argument linear function ``x -> x * slope + intercept``.

    :param slope: multiplicative factor (default 1).
    :param intercept: additive offset (default 0).
    :returns: ``functools.partial`` of ``linear`` with both keywords bound.
    """
    return partial(linear, slope=slope, intercept=intercept)

60 

61 

def logistic_constructor(slope=1, intercept=0):
    """
    Build a one-argument logistic function of ``x * slope + intercept``.

    :param slope: multiplicative factor (default 1).
    :param intercept: additive offset (default 0).
    :returns: ``functools.partial`` of ``logistic`` with both keywords bound.
    """
    return partial(logistic, slope=slope, intercept=intercept)

70 

71 

def identity_link_constructor():
    """
    Build a one-argument identity link function.

    :returns: ``functools.partial`` wrapping ``identity_link`` with no bound arguments.
    """
    return partial(identity_link)

77 

78 

def linear_link_constructor(slope=1, intercept=0):
    """
    Build a one-argument linear link function ``x -> x * slope + intercept``.

    :param slope: multiplicative factor (default 1).
    :param intercept: additive offset (default 0).
    :returns: ``functools.partial`` of ``linear_link`` with both keywords bound.
    """
    return partial(linear_link, slope=slope, intercept=intercept)

87 

88 

def softmax_link_constructor(slope=1, intercept=0):
    """
    Build a one-argument softmax link function of ``x * slope + intercept``.

    :param slope: multiplicative factor (default 1).
    :param intercept: additive offset (default 0).
    :returns: ``functools.partial`` of ``softmax_link`` with both keywords bound.
    """
    return partial(softmax_link, slope=slope, intercept=intercept)

97 

98 

# Registry of ready-made link functions, keyed by name. Each entry is the
# result of a *_constructor call above, i.e. a functools.partial with default
# slope/intercept bound. Callers may therefore still override those keywords.
# NOTE: "identity" is the linear link with slope=1, intercept=0 (numerically
# the identity), not identity_link_constructor(). It returns a new tensor
# rather than the input object.
LINK_FUNCTIONS = dict(
    identity=linear_link_constructor(),
    logistic=logistic_constructor(),
    softmax=softmax_link_constructor(),
)