Coverage for contextualized/regression/metamodels.py: 100% (76 statements)
1"""
2Metamodels which generate context-specific models.
3"""
5import torch
6from torch import nn
8from contextualized.modules import ENCODERS, Explainer, SoftSelect
9from contextualized.functions import LINK_FUNCTIONS


class NaiveMetamodel(nn.Module):
    """Probabilistic assumptions as a graphical model (observed) {unobserved}:

    (C) --> {beta, mu} --> (X, Y)
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        y_dim,
        univariate=False,
        encoder_type="mlp",
        encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["identity"],
        },
    ):
        """
        context_dim (int): dimension of flattened context
        x_dim (int): dimension of flattened features
        y_dim (int): dimension of flattened labels

        Keyword args:
        univariate (bool: False): flag to solve a univariate regression problem instead
            of the standard multivariate problem
        encoder_type (str: mlp): encoder module to use
        encoder_kwargs (dict): encoder args and kwargs
        """
        super().__init__()
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.y_dim = y_dim

        encoder = ENCODERS[encoder_type]
        # In the univariate case each (target, feature) pair gets its own offset,
        # so mu has x_dim columns; otherwise one offset per target suffices.
        self.mu_dim = x_dim if univariate else 1
        out_dim = (x_dim + self.mu_dim) * y_dim
        self.context_encoder = encoder(context_dim, out_dim, **encoder_kwargs)

    def forward(self, C):
        """
        :param C: (n, context_dim) tensor of contexts
        :return: beta (n, y_dim, x_dim) and mu (n, y_dim, mu_dim), the
            context-specific regression parameters
        """
        W = self.context_encoder(C)
        W = torch.reshape(W, (W.shape[0], self.y_dim, self.x_dim + self.mu_dim))
        beta = W[:, :, : self.x_dim]
        mu = W[:, :, self.x_dim :]
        return beta, mu
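
# A minimal usage sketch (illustrative only; dimensions are arbitrary and the
# default "mlp" encoder is assumed):
#
#     metamodel = NaiveMetamodel(context_dim=4, x_dim=3, y_dim=2)
#     C = torch.randn(16, 4)           # batch of 16 contexts
#     beta, mu = metamodel(C)          # beta: (16, 2, 3), mu: (16, 2, 1)
#
# Each row of C yields its own regression model, so predictions can be made as,
# e.g., Y_hat = (beta @ X.unsqueeze(-1)).squeeze(-1) + mu.squeeze(-1)
# for a feature batch X of shape (16, 3).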


class SubtypeMetamodel(nn.Module):
    """Probabilistic assumptions as a graphical model (observed) {unobserved}:

    (C) <-- {Z} --> {beta, mu} --> (X, Y)

    Z: latent variable, causal parent of both the context and the regression model
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        y_dim,
        univariate=False,
        num_archetypes=10,
        encoder_type="mlp",
        encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["identity"],
        },
    ):
        """
        context_dim (int): dimension of flattened context
        x_dim (int): dimension of flattened features
        y_dim (int): dimension of flattened labels

        Keyword args:
        univariate (bool: False): flag to solve a univariate regression problem instead
            of the standard multivariate problem
        num_archetypes (int: 10): number of atomic regression models in {Z}
        encoder_type (str: mlp): encoder module to use
        encoder_kwargs (dict): encoder args and kwargs
        """
        super().__init__()
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.y_dim = y_dim

        encoder = ENCODERS[encoder_type]
        out_shape = (y_dim, x_dim * 2, 1) if univariate else (y_dim, x_dim + 1)
        # The encoder outputs archetype weights; the Explainer mixes
        # num_archetypes atomic models into sample-specific parameters.
        self.context_encoder = encoder(context_dim, num_archetypes, **encoder_kwargs)
        self.explainer = Explainer(num_archetypes, out_shape)

    def forward(self, C):
        """
        :param C: (n, context_dim) tensor of contexts
        :return: beta and mu, the context-specific regression parameters
        """
        Z = self.context_encoder(C)
        W = self.explainer(Z)
        beta = W[:, :, : self.x_dim]
        mu = W[:, :, self.x_dim :]
        return beta, mu
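
# Sketch of the archetype mixing (illustrative; shapes follow the code above):
#
#     metamodel = SubtypeMetamodel(context_dim=4, x_dim=3, y_dim=2, num_archetypes=5)
#     beta, mu = metamodel(torch.randn(16, 4))   # beta: (16, 2, 3), mu: (16, 2, 1)
#
# Unlike NaiveMetamodel, which regresses parameters directly from context, every
# SubtypeMetamodel output is a combination of num_archetypes learned atomic
# models, so num_archetypes bounds the complexity of the model family.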


class MultitaskMetamodel(nn.Module):
    """Probabilistic assumptions as a graphical model (observed) {unobserved}:

    (C) <-- {Z} --> {beta, mu} --> (X, Y)
    (T) <---/

    Z: latent variable, causal parent of the context, regression model, and task (T)
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        y_dim,
        univariate=False,
        num_archetypes=10,
        encoder_type="mlp",
        encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["identity"],
        },
    ):
        """
        context_dim (int): dimension of flattened context
        x_dim (int): dimension of flattened features
        y_dim (int): dimension of flattened labels

        Keyword args:
        univariate (bool: False): flag to solve a univariate regression problem instead
            of the standard multivariate problem
        num_archetypes (int: 10): number of atomic regression models in {Z}
        encoder_type (str: mlp): encoder module to use
        encoder_kwargs (dict): encoder args and kwargs
        """
        super().__init__()
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.y_dim = y_dim

        encoder = ENCODERS[encoder_type]
        beta_dim = 1 if univariate else x_dim
        # The task indicator is concatenated onto the context; in the univariate
        # case it identifies a (target, feature) pair, hence the extra x_dim.
        task_dim = y_dim + x_dim if univariate else y_dim
        self.context_encoder = encoder(
            context_dim + task_dim, num_archetypes, **encoder_kwargs
        )
        self.explainer = Explainer(num_archetypes, (beta_dim + 1,))

    def forward(self, C, T):
        """
        :param C: (n, context_dim) tensor of contexts
        :param T: (n, task_dim) tensor of task indicators
        :return: beta (n, beta_dim) and mu (n, 1), the context- and
            task-specific regression parameters
        """
        CT = torch.cat((C, T), 1)
        Z = self.context_encoder(CT)
        W = self.explainer(Z)
        beta = W[:, :-1]
        mu = W[:, -1:]
        return beta, mu
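
# Usage sketch (illustrative; T is assumed to be a one-hot task indicator of
# width y_dim in the multivariate case):
#
#     metamodel = MultitaskMetamodel(context_dim=4, x_dim=3, y_dim=2)
#     T = torch.eye(2)[torch.randint(2, (16,))]    # (16, 2) one-hot tasks
#     beta, mu = metamodel(torch.randn(16, 4), T)  # beta: (16, 3), mu: (16, 1)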


class TasksplitMetamodel(nn.Module):
    """Probabilistic assumptions as a graphical model (observed) {unobserved}:

    (C) <-- {Z_c} --> {beta, mu} --> (X, Y)
    (T) <-- {Z_t} ----^

    Z_c: latent context variable, causal parent of the context and regression model
    Z_t: latent task variable, causal parent of the task and regression model
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        y_dim,
        univariate=False,
        context_archetypes=10,
        task_archetypes=10,
        context_encoder_type="mlp",
        context_encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["softmax"],
        },
        task_encoder_type="mlp",
        task_encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["identity"],
        },
    ):
        """
        context_dim (int): dimension of flattened context
        x_dim (int): dimension of flattened features
        y_dim (int): dimension of flattened labels

        Keyword args:
        univariate (bool: False): flag to solve a univariate regression problem instead
            of the standard multivariate problem
        context_archetypes (int: 10): number of atomic regression models in {Z_c}
        task_archetypes (int: 10): number of atomic regression models in {Z_t}
        context_encoder_type (str: mlp): context encoder module to use
        context_encoder_kwargs (dict): context encoder args and kwargs
        task_encoder_type (str: mlp): task encoder module to use
        task_encoder_kwargs (dict): task encoder args and kwargs
        """
        super().__init__()
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.y_dim = y_dim

        context_encoder = ENCODERS[context_encoder_type]
        task_encoder = ENCODERS[task_encoder_type]
        beta_dim = 1 if univariate else x_dim
        task_dim = y_dim + x_dim if univariate else y_dim
        self.context_encoder = context_encoder(
            context_dim, context_archetypes, **context_encoder_kwargs
        )
        self.task_encoder = task_encoder(
            task_dim, task_archetypes, **task_encoder_kwargs
        )
        # SoftSelect combines the context latent and the task latent into a
        # single (beta, mu) vector per sample.
        self.explainer = SoftSelect(
            (context_archetypes, task_archetypes), (beta_dim + 1,)
        )

    def forward(self, C, T):
        """
        :param C: (n, context_dim) tensor of contexts
        :param T: (n, task_dim) tensor of task indicators
        :return: beta (n, beta_dim) and mu (n, 1), the context- and
            task-specific regression parameters
        """
        Z_c = self.context_encoder(C)
        Z_t = self.task_encoder(T)
        W = self.explainer(Z_c, Z_t)
        beta = W[:, :-1]
        mu = W[:, -1:]
        return beta, mu
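
# Usage sketch (illustrative): TasksplitMetamodel takes the same (C, T) inputs
# as MultitaskMetamodel, but encodes context and task with separate encoders
# and combines the two latents via SoftSelect rather than concatenating the
# raw inputs:
#
#     metamodel = TasksplitMetamodel(context_dim=4, x_dim=3, y_dim=2)
#     C = torch.randn(16, 4)
#     T = torch.eye(2)[torch.randint(2, (16,))]
#     beta, mu = metamodel(C, T)       # beta: (16, 3), mu: (16, 1)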


SINGLE_TASK_METAMODELS = {
    "naive": NaiveMetamodel,
    "subtype": SubtypeMetamodel,
}

MULTITASK_METAMODELS = {
    "multitask": MultitaskMetamodel,
    "tasksplit": TasksplitMetamodel,
}
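

if __name__ == "__main__":
    # Smoke-test sketch (illustrative only): build each single-task metamodel
    # from the registry with arbitrary dimensions and check the output shapes.
    n_samples, context_dim, x_dim, y_dim = 8, 4, 3, 2
    C = torch.randn(n_samples, context_dim)
    for name, metamodel_cls in SINGLE_TASK_METAMODELS.items():
        metamodel = metamodel_cls(context_dim, x_dim, y_dim)
        beta, mu = metamodel(C)
        print(name, tuple(beta.shape), tuple(mu.shape))  # (8, 2, 3) (8, 2, 1)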