Coverage for contextualized/easy/tests/test_correlation_networks.py: 97%
30 statements
« prev ^ index » next — coverage.py v7.4.4, created at 2024-04-18 16:32 -0400
"""
Unit tests for the Easy ``ContextualizedCorrelationNetworks`` model.
"""
5import unittest
6import numpy as np
7import torch
8from contextualized.easy import ContextualizedCorrelationNetworks
9from contextualized.easy.tests.test_networks import TestEasyNetworks
class TestContextualizedCorrelationNetworks(TestEasyNetworks):
    """
    Test Contextualized Correlation Network models.

    Model fitting is delegated to ``self._quicktest``, which is inherited
    from ``TestEasyNetworks`` (imported from ``test_networks``).
    """

    def setUp(self):
        """
        Shared unit test setup code.

        Builds random contexts ``self.C`` of shape (n_samples, c_dim) and
        features ``self.X`` of shape (n_samples, x_dim), each drawn uniformly
        from [-0.5, 0.5) with torch and converted to NumPy arrays.
        """
        self.n_samples = 100
        self.c_dim = 4
        self.x_dim = 5
        C = torch.rand((self.n_samples, self.c_dim)) - 0.5
        # TODO: Use graph utils to generate X from a network.
        X = torch.rand((self.n_samples, self.x_dim)) - 0.5
        self.C, self.X = C.numpy(), X.numpy()

    def test_correlation(self):
        """
        Test Case for ContextualizedCorrelationNetworks.

        Fits several model configurations (default, with a validation split,
        with 16 archetypes, and with an "ngam" encoder), then checks the
        output shapes of ``predict_correlation`` and that squared
        correlations are non-negative.
        """
        # Shared training settings for every _quicktest call below.
        fit_kwargs = {"max_epochs": 10, "learning_rate": 1e-3}

        model = ContextualizedCorrelationNetworks()
        self._quicktest(model, self.C, self.X, **fit_kwargs)
        self._quicktest(model, self.C, self.X, val_split=0.5, **fit_kwargs)

        model = ContextualizedCorrelationNetworks(num_archetypes=16)
        self._quicktest(model, self.C, self.X, **fit_kwargs)

        model = ContextualizedCorrelationNetworks(
            encoder_type="ngam", num_archetypes=16
        )
        self._quicktest(model, self.C, self.X, **fit_kwargs)

        # One (x_dim x x_dim) correlation matrix per context sample.
        per_sample = (self.n_samples, self.x_dim, self.x_dim)

        # NOTE(review): the extra leading axis of size 1 presumably indexes
        # individual fitted models (collapsed by individual_preds=False) —
        # confirm against ContextualizedCorrelationNetworks.predict_correlation.
        rho = model.predict_correlation(self.C, squared=False)
        self.assertEqual(rho.shape, (1, *per_sample))

        rho = model.predict_correlation(self.C, individual_preds=False, squared=False)
        self.assertEqual(rho.shape, per_sample)

        rho_squared = model.predict_correlation(self.C, squared=True)
        self.assertGreaterEqual(np.min(rho_squared), 0)
        self.assertEqual(rho_squared.shape, (1, *per_sample))
if __name__ == "__main__":
    # Executed only when this file is run as a script (not when imported by a
    # test runner): hand control to unittest's CLI for the current module.
    unittest.main(module="__main__")