Coverage for contextualized/regression/metamodels.py: 100%

76 statements  

coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2Metamodels which generate context-specific models. 

3""" 

4 

5import torch 

6from torch import nn 

7 

8from contextualized.modules import ENCODERS, Explainer, SoftSelect 

9from contextualized.functions import LINK_FUNCTIONS 

10 

11 

class NaiveMetamodel(nn.Module):
    """Probabilistic assumptions as a graphical model (observed) {unobserved}:
    (C) --> {beta, mu} --> (X, Y)
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        y_dim,
        univariate=False,
        encoder_type="mlp",
        encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["identity"],
        },
    ):
        """
        context_dim (int): dimension of flattened context
        x_dim (int): dimension of flattened features
        y_dim (int): dimension of flattened labels

        keyword args:
        univariate (bool: False): flag to solve a univariate regression problem
            instead of the standard multivariate problem
        encoder_type (str: mlp): encoder module to use
        encoder_kwargs (dict): encoder args and kwargs
        """
        super().__init__()
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.y_dim = y_dim

        encoder = ENCODERS[encoder_type]
        self.mu_dim = x_dim if univariate else 1
        out_dim = (x_dim + self.mu_dim) * y_dim
        self.context_encoder = encoder(context_dim, out_dim, **encoder_kwargs)

    def forward(self, C):
        """
        :param C: context tensor of shape (n_samples, context_dim)
        """
        W = self.context_encoder(C)
        W = torch.reshape(W, (W.shape[0], self.y_dim, self.x_dim + self.mu_dim))
        beta = W[:, :, : self.x_dim]
        mu = W[:, :, self.x_dim :]
        return beta, mu
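A minimal usage sketch (illustrative, not part of the module): with the multivariate default, the encoder emits (x_dim + mu_dim) * y_dim values per sample, which forward reshapes into per-output coefficients and offsets.

    # Sketch only; assumes the ENCODERS["mlp"] batch semantics the constructor requires.
    naive = NaiveMetamodel(context_dim=4, x_dim=3, y_dim=2)
    C = torch.randn(8, 4)            # batch of 8 flattened contexts
    beta, mu = naive(C)
    assert beta.shape == (8, 2, 3)   # (n_samples, y_dim, x_dim)
    assert mu.shape == (8, 2, 1)     # mu_dim == 1 when univariate=False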

class SubtypeMetamodel(nn.Module):
    """Probabilistic assumptions as a graphical model (observed) {unobserved}:
    (C) <-- {Z} --> {beta, mu} --> (X)

    Z: latent variable, causal parent of both the context and regression model
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        y_dim,
        univariate=False,
        num_archetypes=10,
        encoder_type="mlp",
        encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["identity"],
        },
    ):
        """
        context_dim (int): dimension of flattened context
        x_dim (int): dimension of flattened features
        y_dim (int): dimension of flattened labels

        keyword args:
        univariate (bool: False): flag to solve a univariate regression problem
            instead of the standard multivariate problem
        num_archetypes (int: 10): number of atomic regression models in {Z}
        encoder_type (str: mlp): encoder module to use
        encoder_kwargs (dict): encoder args and kwargs
        """
        super().__init__()
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.y_dim = y_dim

        encoder = ENCODERS[encoder_type]
        out_shape = (y_dim, x_dim * 2, 1) if univariate else (y_dim, x_dim + 1)
        self.context_encoder = encoder(context_dim, num_archetypes, **encoder_kwargs)
        self.explainer = Explainer(num_archetypes, out_shape)

    def forward(self, C):
        """
        :param C: context tensor of shape (n_samples, context_dim)
        """
        Z = self.context_encoder(C)
        W = self.explainer(Z)
        beta = W[:, :, : self.x_dim]
        mu = W[:, :, self.x_dim :]
        return beta, mu
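A hedged usage sketch: the interface matches NaiveMetamodel, but each sample's parameters are a learned combination of num_archetypes atomic models, so the archetype count (rather than the encoder's output width) bounds model complexity. Shapes assume the batch semantics of Explainer in contextualized.modules.

    # Sketch only, multivariate default.
    subtype = SubtypeMetamodel(context_dim=4, x_dim=3, y_dim=2, num_archetypes=5)
    C = torch.randn(8, 4)
    beta, mu = subtype(C)            # each sample softly mixes the 5 archetype models
    assert beta.shape == (8, 2, 3)
    assert mu.shape == (8, 2, 1)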

class MultitaskMetamodel(nn.Module):
    """Probabilistic assumptions as a graphical model (observed) {unobserved}:
    (C) <-- {Z} --> {beta, mu} --> (X)
    (T) <---/

    Z: latent variable, causal parent of the context, regression model, and task (T)
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        y_dim,
        univariate=False,
        num_archetypes=10,
        encoder_type="mlp",
        encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["identity"],
        },
    ):
        """
        context_dim (int): dimension of flattened context
        x_dim (int): dimension of flattened features
        y_dim (int): dimension of flattened labels

        keyword args:
        univariate (bool: False): flag to solve a univariate regression problem
            instead of the standard multivariate problem
        num_archetypes (int: 10): number of atomic regression models in {Z}
        encoder_type (str: mlp): encoder module to use
        encoder_kwargs (dict): encoder args and kwargs
        """
        super().__init__()
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.y_dim = y_dim

        encoder = ENCODERS[encoder_type]
        beta_dim = 1 if univariate else x_dim
        task_dim = y_dim + x_dim if univariate else y_dim
        self.context_encoder = encoder(
            context_dim + task_dim, num_archetypes, **encoder_kwargs
        )
        self.explainer = Explainer(num_archetypes, (beta_dim + 1,))

    def forward(self, C, T):
        """
        :param C: context tensor of shape (n_samples, context_dim)
        :param T: task tensor of shape (n_samples, task_dim)
        """
        CT = torch.cat((C, T), 1)
        Z = self.context_encoder(CT)
        W = self.explainer(Z)
        beta = W[:, :-1]
        mu = W[:, -1:]
        return beta, mu
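A sketch of the multitask interface (illustrative): the task indicator is concatenated with the context, so forward returns one flat parameter row per (sample, task) pair. The one-hot encoding of T below is an assumption about how callers represent tasks, not something this module enforces.

    multitask = MultitaskMetamodel(context_dim=4, x_dim=3, y_dim=2)
    C = torch.randn(8, 4)
    T = torch.eye(2)[torch.randint(2, (8,))]   # assumed one-hot tasks, shape (8, y_dim)
    beta, mu = multitask(C, T)
    assert beta.shape == (8, 3)   # (n_samples, beta_dim); beta_dim == x_dim here
    assert mu.shape == (8, 1)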

class TasksplitMetamodel(nn.Module):
    """Probabilistic assumptions as a graphical model (observed) {unobserved}:
    (C) <-- {Z_c} --> {beta, mu} --> (X)
    (T) <-- {Z_t} ----^

    Z_c: latent context variable, causal parent of the context and regression model
    Z_t: latent task variable, causal parent of the task and regression model
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        y_dim,
        univariate=False,
        context_archetypes=10,
        task_archetypes=10,
        context_encoder_type="mlp",
        context_encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["softmax"],
        },
        task_encoder_type="mlp",
        task_encoder_kwargs={
            "width": 25,
            "layers": 1,
            "link_fn": LINK_FUNCTIONS["identity"],
        },
    ):
        """
        context_dim (int): dimension of flattened context
        x_dim (int): dimension of flattened features
        y_dim (int): dimension of flattened labels

        keyword args:
        univariate (bool: False): flag to solve a univariate regression problem
            instead of the standard multivariate problem
        context_archetypes (int: 10): number of atomic regression models in {Z_c}
        task_archetypes (int: 10): number of atomic regression models in {Z_t}
        context_encoder_type (str: mlp): context encoder module to use
        context_encoder_kwargs (dict): context encoder args and kwargs
        task_encoder_type (str: mlp): task encoder module to use
        task_encoder_kwargs (dict): task encoder args and kwargs
        """
        super().__init__()
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.y_dim = y_dim

        context_encoder = ENCODERS[context_encoder_type]
        task_encoder = ENCODERS[task_encoder_type]
        beta_dim = 1 if univariate else x_dim
        task_dim = y_dim + x_dim if univariate else y_dim
        self.context_encoder = context_encoder(
            context_dim, context_archetypes, **context_encoder_kwargs
        )
        self.task_encoder = task_encoder(
            task_dim, task_archetypes, **task_encoder_kwargs
        )
        self.explainer = SoftSelect(
            (context_archetypes, task_archetypes), (beta_dim + 1,)
        )

    def forward(self, C, T):
        """
        :param C: context tensor of shape (n_samples, context_dim)
        :param T: task tensor of shape (n_samples, task_dim)
        """
        Z_c = self.context_encoder(C)
        Z_t = self.task_encoder(T)
        W = self.explainer(Z_c, Z_t)
        beta = W[:, :-1]
        mu = W[:, -1:]
        return beta, mu
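A sketch mirroring the multitask example: here context and task get separate encoders, and SoftSelect combines the two archetype scores into one parameter row per sample, so output shapes match MultitaskMetamodel. Again, the one-hot T is an assumed task encoding.

    tasksplit = TasksplitMetamodel(context_dim=4, x_dim=3, y_dim=2)
    C = torch.randn(8, 4)
    T = torch.eye(2)[torch.randint(2, (8,))]   # assumed one-hot tasks
    beta, mu = tasksplit(C, T)
    assert beta.shape == (8, 3)
    assert mu.shape == (8, 1)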

SINGLE_TASK_METAMODELS = {
    "naive": NaiveMetamodel,
    "subtype": SubtypeMetamodel,
}

MULTITASK_METAMODELS = {
    "multitask": MultitaskMetamodel,
    "tasksplit": TasksplitMetamodel,
}
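These registries let callers select a metamodel class by string key, mirroring the ENCODERS lookup used inside each constructor. A minimal example:

    metamodel_cls = SINGLE_TASK_METAMODELS["subtype"]
    metamodel = metamodel_cls(context_dim=4, x_dim=3, y_dim=2)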