Coverage for contextualized/tests.py: 99%

165 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1import os 

2import unittest 

3import numpy as np 

4import torch 

5from contextualized.modules import SoftSelect, Explainer, MLP, NGAM, Linear 

6from contextualized.easy import ( 

7 ContextualizedRegressor, 

8 ContextualizedBayesianNetworks, 

9 ContextualizedCorrelationNetworks, 

10) 

11from contextualized.baselines import BayesianNetwork, CorrelationNetwork 

12from contextualized.utils import save, load 

13 

14 

class TestModules(unittest.TestCase):
    """
    Output-shape and archetype get/set round-trip checks for contextualized.modules.
    """

    def setUp(self):
        """
        Shared data setup: random inputs sized for every module under test.
        """
        self.N_SAMPLES = 100
        self.X_DIM = 10
        self.Y_DIM = 5
        self.WIDTH = 50
        self.LAYERS = 5
        self.X_data = torch.rand((self.N_SAMPLES, self.X_DIM))
        # SoftSelect is fed one tensor per entry of IN_DIMS (Z1 -> IN_DIMS[0], Z2 -> IN_DIMS[1]).
        self.IN_DIMS = (3, 4)
        self.OUT_SHAPE = (5, 6)
        self.Z1 = torch.randn(self.N_SAMPLES, self.IN_DIMS[0])
        self.Z2 = torch.randn(self.N_SAMPLES, self.IN_DIMS[1])

    def test_mlp(self):
        """
        Test that the output shape of the MLP is as expected.
        """
        mlp = MLP(self.X_DIM, self.Y_DIM, self.WIDTH, self.LAYERS)
        self.assertEqual(mlp(self.X_data).shape, (self.N_SAMPLES, self.Y_DIM))

    def test_ngam(self):
        """
        Test that the output shape of the NGAM is as expected.
        """
        ngam = NGAM(self.X_DIM, self.Y_DIM, self.WIDTH, self.LAYERS)
        self.assertEqual(ngam(self.X_data).shape, (self.N_SAMPLES, self.Y_DIM))

    def test_softselect(self):
        """
        Test the SoftSelect output shape and that get/set_archetypes round-trip.
        """
        softselect = SoftSelect(self.IN_DIMS, self.OUT_SHAPE)
        self.assertEqual(
            softselect(self.Z1, self.Z2).shape, (self.N_SAMPLES, *self.OUT_SHAPE)
        )

        # The stored parameter is laid out (*OUT_SHAPE, *IN_DIMS), while
        # get_archetypes exposes (*IN_DIMS, *OUT_SHAPE).
        precycle_vals = softselect.archetypes
        self.assertEqual(precycle_vals.shape, (*self.OUT_SHAPE, *self.IN_DIMS))
        postcycle_vals = softselect.get_archetypes()
        self.assertEqual(postcycle_vals.shape, (*self.IN_DIMS, *self.OUT_SHAPE))
        # Setting fresh random archetypes must change the stored values...
        softselect.set_archetypes(torch.randn(*self.IN_DIMS, *self.OUT_SHAPE))
        self.assertTrue((softselect.archetypes != precycle_vals).any())
        # ...and setting back what get_archetypes returned must restore them exactly.
        softselect.set_archetypes(postcycle_vals)
        self.assertTrue((softselect.archetypes == precycle_vals).all())

    def test_explainer(self):
        """
        Test the Explainer output shape and that get/set_archetypes round-trip.
        """
        explainer = Explainer(self.IN_DIMS[0], self.OUT_SHAPE)
        ret = explainer(self.Z1)
        # The forward result was previously computed but never checked.
        self.assertEqual(ret.shape, (self.N_SAMPLES, *self.OUT_SHAPE))

        # Same layout contract as SoftSelect with a single input dimension.
        precycle_vals = explainer.archetypes
        self.assertEqual(precycle_vals.shape, (*self.OUT_SHAPE, self.IN_DIMS[0]))
        postcycle_vals = explainer.get_archetypes()
        self.assertEqual(postcycle_vals.shape, (self.IN_DIMS[0], *self.OUT_SHAPE))
        explainer.set_archetypes(torch.randn(self.IN_DIMS[0], *self.OUT_SHAPE))
        self.assertTrue((explainer.archetypes != precycle_vals).any())
        explainer.set_archetypes(postcycle_vals)
        self.assertTrue((explainer.archetypes == precycle_vals).all())

    def test_linear(self):
        """
        Test that the output shape of the Linear encoder is as expected.
        """
        linear_encoder = Linear(self.X_DIM, self.Y_DIM)
        linear_output = linear_encoder(self.X_data)
        self.assertEqual(linear_output.shape, (self.N_SAMPLES, self.Y_DIM))

83 

84 

class TestSaveLoad(unittest.TestCase):
    """
    Round-trip tests for contextualized.utils.save / load.
    """

    def setUp(self):
        """
        Give each test a private scratch file path.

        The containing temporary directory is removed by addCleanup even when an
        assertion fails, so no model file is left behind in the working directory.
        """
        import tempfile

        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        self.model_path = os.path.join(tmp_dir.name, "unittest_model.pt")

    def _save_and_reload(self, obj):
        """
        Save obj to the scratch path, load it back, delete the file, and return the loaded copy.
        """
        save(obj, self.model_path)
        loaded = load(self.model_path)
        os.remove(self.model_path)
        return loaded

    def _check_round_trips(
        self, model, method, fit_args, predict_args, refit_args, repredict_args
    ):
        """
        Check that a model survives save/load both before and after refitting.

        Fits `model`, saves and reloads it, and checks its predictions are unchanged.
        It then refits the loaded copy on new data and checks the predictions change.
        Finally, it saves and reloads the refit copy and checks its predictions are unchanged.

        Args:
            model: unfitted model exposing `fit` and the prediction method `method`.
            method: name of the prediction method to call (e.g. "predict").
            fit_args / predict_args: positional args for the first fit / prediction.
            refit_args / repredict_args: positional args for the refit / its prediction.
        """

        def predict(m, args):
            return np.array(getattr(m, method)(*args))

        model.fit(*fit_args)
        pred = predict(model, predict_args)
        model_loaded = self._save_and_reload(model)
        pred_loaded = predict(model_loaded, predict_args)
        np.testing.assert_array_equal(pred, pred_loaded)

        # A loaded model must remain trainable: refitting on new data changes its output...
        model_loaded.fit(*refit_args)
        pred2 = predict(model_loaded, repredict_args)
        self.assertFalse(np.array_equal(pred_loaded, pred2))

        # ...and the refit state must itself survive another round trip.
        model_loaded2 = self._save_and_reload(model_loaded)
        pred_loaded2 = predict(model_loaded2, repredict_args)
        np.testing.assert_array_equal(pred2, pred_loaded2)

    def test_save_load(self):
        """
        Test saving and loading of contextualized objects
        """
        # Seeded so the "refit changes predictions" check is reproducible.
        rng = np.random.default_rng(0)
        C, X, Y, C2, X2, Y2 = (rng.uniform(0, 1, size=(100, 2)) for _ in range(6))

        with self.subTest(model="MLP"):
            mlp = MLP(2, 2, 50, 5)
            Y_pred = mlp(torch.Tensor(X)).detach().numpy()
            mlp_loaded = self._save_and_reload(mlp)
            Y_pred_loaded = mlp_loaded(torch.Tensor(X)).detach().numpy()
            np.testing.assert_array_equal(Y_pred, Y_pred_loaded)

        # (model class, prediction method, fit args, predict args, refit args, repredict args)
        cases = [
            (ContextualizedRegressor, "predict", (C, X, Y), (C, X), (C2, X2, Y2), (C2, X2)),
            (ContextualizedBayesianNetworks, "predict_networks", (C, X), (C,), (C2, X2), (C2,)),
            (ContextualizedCorrelationNetworks, "predict_correlation", (C, X), (C,), (C2, X2), (C2,)),
            (BayesianNetwork, "measure_mses", (X,), (X,), (X2,), (X2,)),
            (CorrelationNetwork, "measure_mses", (X,), (X,), (X2,), (X2,)),
        ]
        for model_cls, method, fit_args, predict_args, refit_args, repredict_args in cases:
            with self.subTest(model=model_cls.__name__):
                self._check_round_trips(
                    model_cls(), method, fit_args, predict_args, refit_args, repredict_args
                )

203 

204 

if __name__ == "__main__":
    # Executing this file directly runs every TestCase in this module.
    unittest.main(module="__main__")