Coverage for contextualized/functions.py: 81%
32 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:38 -0400
1"""Utility torch mathematical functions which are used for many modules.
2"""
4import torch
5import torch.nn.functional as F
6from functools import partial
def zero_vector(x, *args):
    """Return a column of zeros with one row per element of ``x``.

    Any extra positional arguments are accepted and ignored.

    :param x: Sized input (e.g. a tensor); only ``len(x)`` determines the shape.
    :param args: Ignored.
    :returns: Default-dtype tensor of shape ``(len(x), 1)``, allocated on
        ``x``'s device when ``x`` is a tensor.
    """
    # Follow the input's device so the zeros can be combined with ``x``
    # without a cross-device error; non-tensor inputs use the default device.
    device = x.device if isinstance(x, torch.Tensor) else None
    return torch.zeros((len(x), 1), device=device)
def zero(x):
    """Return a tensor of zeros shaped like ``x``.

    :param x: Reference tensor whose shape, dtype and device are copied.
    :returns: ``torch.zeros_like`` applied to ``x``.
    """
    return torch.zeros_like(input=x)
def identity(x):
    """Return the argument itself (the same object, not a copy).

    :param x: Any value.
    :returns: ``x``.
    """
    result = x
    return result
def linear(x, slope, intercept):
    """Evaluate the affine map ``x * slope + intercept``.

    :param x: Input value or tensor.
    :param slope: Multiplier applied to ``x``.
    :param intercept: Offset added after scaling.
    :returns: ``x * slope + intercept``.
    """
    scaled = x * slope
    return scaled + intercept
def logistic(x, slope, intercept):
    """Apply the logistic (sigmoid) curve to an affine transform of ``x``.

    Computes ``1 / (1 + exp(-(x * slope + intercept)))``.

    :param x: Input tensor.
    :param slope: Multiplier applied to ``x``.
    :param intercept: Offset added after scaling.
    :returns: Tensor of values in ``[0, 1]``.
    """
    # torch.sigmoid evaluates the same function but keeps its backward pass
    # finite when exp(-z) overflows; the hand-written 1 / (1 + exp(-z)) form
    # produces NaN gradients there (-0 * inf in the chain rule).
    return torch.sigmoid(x * slope + intercept)
def linear_link(x, slope, intercept):
    """Linear link function: ``x * slope + intercept``.

    :param x: Input value or tensor.
    :param slope: Multiplier applied to ``x``.
    :param intercept: Offset added after scaling.
    :returns: The affinely transformed input.
    """
    product = x * slope
    shifted = product + intercept
    return shifted
def identity_link(x):
    """Identity link function: hand back the input object unchanged.

    :param x: Any value.
    :returns: ``x`` itself.
    """
    unchanged = x
    return unchanged
def softmax_link(x, slope, intercept, dim=1):
    """Softmax link function applied to ``x * slope + intercept``.

    :param x: Input tensor.
    :param slope: Multiplier applied to ``x``.
    :param intercept: Offset added after scaling.
    :param dim: Axis along which softmax normalizes (Default value = 1, the
        previous hard-coded axis). Pass ``0`` (or ``-1``) for 1-D inputs.
    :returns: Tensor of the same shape whose entries along ``dim`` sum to 1.
    """
    return F.softmax(x * slope + intercept, dim=dim)
def make_fn(base_fn, **params):
    """Bind keyword arguments to ``base_fn``, leaving its positional inputs free.

    Binding every parameter except the input yields a single-argument callable.

    :param base_fn: Callable to wrap.
    :param params: Keyword arguments to bind to ``base_fn``.
    :returns: A ``functools.partial`` of ``base_fn`` with ``params`` bound.
        Bound keywords can still be overridden at call time.
    """
    bound = partial(base_fn, **params)
    return bound
def linear_constructor(slope=1, intercept=0):
    """Build a single-argument linear function ``x -> x * slope + intercept``.

    :param slope: Multiplier applied to the input (Default value = 1).
    :param intercept: Offset added after scaling (Default value = 0).
    :returns: ``linear`` with ``slope`` and ``intercept`` bound via ``make_fn``.
    """
    bound_params = {"slope": slope, "intercept": intercept}
    return make_fn(linear, **bound_params)
def logistic_constructor(slope=1, intercept=0):
    """Build a single-argument logistic function of ``x * slope + intercept``.

    :param slope: Multiplier applied to the input (Default value = 1).
    :param intercept: Offset added after scaling (Default value = 0).
    :returns: ``logistic`` with ``slope`` and ``intercept`` bound via ``make_fn``.
    """
    fixed = dict(slope=slope, intercept=intercept)
    return make_fn(logistic, **fixed)
def identity_link_constructor():
    """Build a single-argument identity link function.

    :returns: ``identity_link`` wrapped by ``make_fn`` with nothing bound.
    """
    link_fn = make_fn(identity_link)
    return link_fn
def linear_link_constructor(slope=1, intercept=0):
    """Build a single-argument linear link ``x -> x * slope + intercept``.

    :param slope: Multiplier applied to the input (Default value = 1).
    :param intercept: Offset added after scaling (Default value = 0).
    :returns: ``linear_link`` with ``slope`` and ``intercept`` bound via ``make_fn``.
    """
    link_fn = make_fn(linear_link, slope=slope, intercept=intercept)
    return link_fn
def softmax_link_constructor(slope=1, intercept=0):
    """Build a single-argument softmax link of ``x * slope + intercept``.

    :param slope: Multiplier applied to the input (Default value = 1).
    :param intercept: Offset added after scaling (Default value = 0).
    :returns: ``softmax_link`` with ``slope`` and ``intercept`` bound via ``make_fn``.
    """
    fixed = {"slope": slope, "intercept": intercept}
    link_fn = make_fn(softmax_link, **fixed)
    return link_fn
# Name -> ready-made single-argument link function, each built with its
# constructor's default parameters. "identity" is realized as the linear link
# with slope=1 and intercept=0, i.e. x * 1 + 0.
LINK_FUNCTIONS = dict(
    identity=linear_link_constructor(),
    logistic=logistic_constructor(),
    softmax=softmax_link_constructor(),
)