Coverage for contextualized/easy/tests/test_correlation_networks.py: 97%

30 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-18 16:32 -0400

1"""" 

2Unit tests for Easy Networks. 

3""" 

4 

5import unittest 

6import numpy as np 

7import torch 

8from contextualized.easy import ContextualizedCorrelationNetworks 

9from contextualized.easy.tests.test_networks import TestEasyNetworks 

10 

11 

class TestContextualizedCorrelationNetworks(TestEasyNetworks):
    """
    Test Contextualized Correlation Network models.

    Uses the ``_quicktest`` helper, which is not defined in this class and is
    presumably inherited from ``TestEasyNetworks``.
    """

    def setUp(self):
        """
        Shared unit test setup code.

        Builds a context matrix ``self.C`` of shape (n_samples, c_dim) and a
        feature matrix ``self.X`` of shape (n_samples, x_dim), both drawn
        uniformly from [-0.5, 0.5) and stored as NumPy arrays.
        """
        self.n_samples = 100
        self.c_dim = 4
        self.x_dim = 5
        # A dedicated, seeded generator makes the fixture reproducible across
        # runs without mutating torch's global RNG state.
        generator = torch.Generator().manual_seed(0)
        C = torch.rand((self.n_samples, self.c_dim), generator=generator) - 0.5
        # TODO: Use graph utils to generate X from a network.
        X = torch.rand((self.n_samples, self.x_dim), generator=generator) - 0.5
        self.C, self.X = C.numpy(), X.numpy()

    def test_correlation(self):
        """
        Test Case for ContextualizedCorrelationNetworks.

        Runs ``_quicktest`` on several model configurations (default,
        ``num_archetypes=16``, and ``encoder_type="ngam"`` with 16 archetypes).
        It then checks the output shapes of ``predict_correlation`` with and
        without ``individual_preds``, and that ``squared=True`` yields
        non-negative values.
        """
        model = ContextualizedCorrelationNetworks()
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
        # Same model instance again, now with val_split=0.5 (presumably the
        # fraction of samples held out for validation).
        self._quicktest(
            model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5
        )

        model = ContextualizedCorrelationNetworks(num_archetypes=16)
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)

        model = ContextualizedCorrelationNetworks(
            encoder_type="ngam", num_archetypes=16
        )
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)

        # unittest assertion methods (unlike bare `assert`) survive `python -O`
        # and report both operands on failure.
        rho = model.predict_correlation(self.C, squared=False)
        self.assertEqual(rho.shape, (1, self.n_samples, self.x_dim, self.x_dim))

        rho = model.predict_correlation(self.C, individual_preds=False, squared=False)
        self.assertEqual(rho.shape, (self.n_samples, self.x_dim, self.x_dim))

        rho_squared = model.predict_correlation(self.C, squared=True)
        self.assertGreaterEqual(np.min(rho_squared), 0)
        self.assertEqual(
            rho_squared.shape, (1, self.n_samples, self.x_dim, self.x_dim)
        )

54 

55 

if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()