Coverage for contextualized/modules.py: 100% (65 statements)
coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
1"""
2PyTorch modules which are used as building blocks of Contextualized models.
3"""
5import torch
6from torch import nn
8from contextualized.functions import identity_link


class SoftSelect(nn.Module):
    """
    Parameter sharing for multiple context encoders:
    Batched computation for mapping many subtypes onto d-dimensional archetypes
    """

    def __init__(self, in_dims, out_shape):
        super().__init__()
        self.in_dims = in_dims
        self.out_shape = out_shape
        init_mat = torch.rand(list(out_shape) + list(in_dims)) * 2e-2 - 1e-2
        self.archetypes = nn.parameter.Parameter(init_mat, requires_grad=True)

    def forward(self, *batch_weights):
        """Torch Forward pass."""
        batch_size = batch_weights[0].shape[0]
        expand_dims = [batch_size] + [-1 for _ in range(len(self.archetypes.shape))]
        batch_archetypes = self.archetypes.unsqueeze(0).expand(expand_dims)
        # Contract the trailing in_dims one at a time, last weight tensor first.
        for batch_w in batch_weights[::-1]:
            batch_w = batch_w.unsqueeze(-1)
            empty_dims = len(batch_archetypes.shape) - len(batch_w.shape)
            for _ in range(empty_dims):
                batch_w = batch_w.unsqueeze(1)
            batch_archetypes = torch.matmul(batch_archetypes, batch_w).squeeze(-1)
        return batch_archetypes

    def _cycle_dims(self, tensor, n_steps):
        """
        Cycle tensor dimensions from back to front for n_steps:
        each step moves the last dimension to the front.
        """
        for _ in range(n_steps):
            tensor = tensor.unsqueeze(0).transpose(0, -1).squeeze(-1)
        return tensor

    def get_archetypes(self):
        """
        Returns archetype parameters: (*in_dims, *out_shape)
        """
        return self._cycle_dims(self.archetypes, len(self.in_dims))

    def set_archetypes(self, archetypes):
        """
        Sets archetype parameters

        Requires archetypes.shape == (*in_dims, *out_shape)
        """
        self.archetypes = nn.parameter.Parameter(
            self._cycle_dims(archetypes, len(self.out_shape)), requires_grad=True
        )
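

# Illustration (added sketch, not in the original module): how SoftSelect
# contracts per-sample subtype weights against its archetype tensor. The
# shapes (8 samples, 4 subtypes, (3, 2) output) are assumptions chosen for
# demonstration.
def _demo_soft_select():
    select = SoftSelect(in_dims=(4,), out_shape=(3, 2))
    weights = torch.softmax(torch.randn(8, 4), dim=-1)  # convex subtype weights
    params = select(weights)
    assert params.shape == (8, 3, 2)
    # get_archetypes/set_archetypes round-trip in (*in_dims, *out_shape) layout;
    # detach so we pass plain values rather than a view of the live parameter.
    arch = select.get_archetypes()
    assert arch.shape == (4, 3, 2)
    select.set_archetypes(arch.detach())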


class Explainer(SoftSelect):
    """
    2D subtype-archetype parameter sharing
    """

    def __init__(self, k, out_shape):
        super().__init__((k,), out_shape)
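

# Illustration (added sketch): Explainer is SoftSelect with a single subtype
# axis, so a (batch, k) weight matrix selects (batch, *out_shape) parameters.
# k=5, out_shape=(10,), and the batch size are assumptions.
def _demo_explainer():
    explainer = Explainer(k=5, out_shape=(10,))
    weights = torch.softmax(torch.randn(32, 5), dim=-1)
    assert explainer(weights).shape == (32, 10)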


class MLP(nn.Module):
    """
    Multi-layer perceptron
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        width,
        layers,
        activation=nn.ReLU,
        link_fn=identity_link,
    ):
        super().__init__()
        if layers > 0:
            mlp_layers = [nn.Linear(input_dim, width), activation()]
            for _ in range(layers - 1):
                mlp_layers += [nn.Linear(width, width), activation()]
            mlp_layers.append(nn.Linear(width, output_dim))
        else:  # Linear encoder
            mlp_layers = [nn.Linear(input_dim, output_dim)]
        self.mlp = nn.Sequential(*mlp_layers)
        self.link_fn = link_fn

    def forward(self, X):
        """Torch Forward pass."""
        ret = self.mlp(X)
        return self.link_fn(ret)
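

# Illustration (added sketch): the hidden width of 16 and 2 hidden layers are
# arbitrary assumptions; `layers` counts hidden layers, with layers=0 giving
# a purely linear map.
def _demo_mlp():
    mlp = MLP(input_dim=10, output_dim=3, width=16, layers=2)
    X = torch.randn(8, 10)
    assert mlp(X).shape == (8, 3)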


class NGAM(nn.Module):
    """
    Neural generalized additive model
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        width,
        layers,
        activation=nn.ReLU,
        link_fn=identity_link,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # One small MLP per input feature; the shared link_fn is applied to
        # the sum, so each sub-network uses an identity link internally.
        self.nams = nn.ModuleList(
            [
                MLP(
                    1,
                    output_dim,
                    width,
                    layers,
                    activation=activation,
                    link_fn=identity_link,
                )
                for _ in range(input_dim)
            ]
        )
        self.link_fn = link_fn

    def forward(self, X):
        """Torch Forward pass."""
        ret = self.nams[0](X[:, 0].unsqueeze(-1))
        # start=1 so each remaining network sees its own feature column
        # (enumerating from 0 would feed feature i to network i + 1).
        for i, nam in enumerate(self.nams[1:], start=1):
            ret = ret + nam(X[:, i].unsqueeze(-1))
        return self.link_fn(ret)
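

# Illustration (added sketch): input_dim=4, width=8, layers=1 are arbitrary.
# The additivity check assumes the default identity_link is the identity
# function, so the output equals the sum of the per-feature networks.
def _demo_ngam():
    ngam = NGAM(input_dim=4, output_dim=2, width=8, layers=1)
    X = torch.randn(16, 4)
    out = ngam(X)
    assert out.shape == (16, 2)
    manual = sum(nam(X[:, i].unsqueeze(-1)) for i, nam in enumerate(ngam.nams))
    assert torch.allclose(out, manual)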


class Linear(nn.Module):
    """
    Linear encoder
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # layers=0 makes the MLP a single nn.Linear, so width and activation
        # are unused here.
        self.linear = MLP(
            input_dim, output_dim, width=output_dim, layers=0, activation=None
        )

    def forward(self, X):
        """Torch Forward pass."""
        return self.linear(X)


ENCODERS = {"mlp": MLP, "ngam": NGAM, "linear": Linear}
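

# Illustration (added sketch): encoders are typically looked up by key from
# this registry. The key "mlp" and the constructor arguments are assumptions;
# note that Linear takes only (input_dim, output_dim).
def _demo_encoder_registry():
    encoder_cls = ENCODERS["mlp"]
    encoder = encoder_cls(input_dim=10, output_dim=4, width=16, layers=2)
    assert encoder(torch.randn(8, 10)).shape == (8, 4)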