Coverage for contextualized/easy/tests/test_markov_networks.py: 96%

27 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-18 16:32 -0400

1"""" 

2Unit tests for Easy Networks. 

3""" 

4 

5import unittest 

6import numpy as np 

7import torch 

8from contextualized.easy.tests.test_networks import TestEasyNetworks 

9from contextualized.easy import ContextualizedMarkovNetworks 

10 

11 

class TestContextualizedMarkovNetworks(TestEasyNetworks):
    """
    Test Contextualized Markov Network models.
    """

    def setUp(self):
        """
        Shared unit test setup code.

        Builds a small random dataset stored as numpy arrays:
        ``self.C`` with shape ``(n_samples, c_dim)`` and ``self.X`` with
        shape ``(n_samples, x_dim)``. Entries are drawn uniformly from
        [-0.5, 0.5).
        """
        # Seed torch's global RNG so the fixtures (and any model init that
        # draws from the same RNG) are reproducible across runs.
        torch.manual_seed(0)
        self.n_samples = 100
        self.c_dim = 4
        self.x_dim = 5
        # torch.rand samples from [0, 1); shift to be centred on zero.
        C = torch.rand((self.n_samples, self.c_dim)) - 0.5
        # TODO: Use graph utils to generate X from a network.
        X = torch.rand((self.n_samples, self.x_dim)) - 0.5
        self.C, self.X = C.numpy(), X.numpy()

    def test_markov(self):
        """Test Case for ContextualizedMarkovNetworks."""
        # NOTE(review): _quicktest is inherited from TestEasyNetworks;
        # presumably it fits the model with the given kwargs and sanity-checks
        # its predictions — confirm against the parent class.

        # Default configuration, without and with a validation split.
        model = ContextualizedMarkovNetworks()
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
        self._quicktest(
            model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5
        )

        # Non-default number of archetypes.
        model = ContextualizedMarkovNetworks(num_archetypes=16)
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)

        # "ngam" encoder; also check the shape of predicted precision matrices.
        model = ContextualizedMarkovNetworks(encoder_type="ngam", num_archetypes=16)
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
        omegas = model.predict_precisions(self.C, individual_preds=False)
        # Use unittest assertions (not bare assert) so the checks survive
        # `python -O` and report both shapes on failure.
        self.assertEqual(
            np.shape(omegas), (self.n_samples, self.x_dim, self.x_dim)
        )
        omegas = model.predict_precisions(self.C, individual_preds=True)
        # individual_preds=True adds a leading axis (size 1 here); presumably
        # one entry per fitted sub-model/bootstrap — TODO confirm.
        self.assertEqual(
            np.shape(omegas), (1, self.n_samples, self.x_dim, self.x_dim)
        )

46 

47 

# Allow running this test module directly as a script via unittest's CLI runner.
if __name__ == "__main__":
    unittest.main()