Coverage for contextualized/easy/tests/test_bayesian_networks.py: 98%

51 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-18 16:32 -0400

1"""

2Unit tests for Easy Bayesian Networks.

3""" 

4 

5import unittest 

6import numpy as np 

7import torch 

8from contextualized.easy import ContextualizedBayesianNetworks 

9from contextualized.easy.tests.test_networks import TestEasyNetworks 

10 

11 

class TestContextualizedBayesianNetworks(TestEasyNetworks):
    """
    Tests for the ContextualizedBayesianNetworks easy-interface model.

    Every test builds a model with one particular configuration and hands it
    to ``_quicktest``, which is not defined in this class (presumably
    inherited from TestEasyNetworks -- confirm against that module).
    """

    def setUp(self):
        """
        Build the random context (C) and feature (X) arrays used by all tests.
        """
        self.n_samples = 100
        self.c_dim = 4
        self.x_dim = 5
        # torch.rand draws from [0, 1); subtracting 0.5 centres it on zero.
        contexts = torch.rand((self.n_samples, self.c_dim)) - 0.5
        # TODO: Use graph utils to generate X from a network.
        features = torch.rand((self.n_samples, self.x_dim)) - 0.5
        self.C = contexts.numpy()
        self.X = features.numpy()

    def test_bayesian_factors(self):
        """Check predicted network shapes when ``num_factors`` is set."""
        n_factors = 2
        model_kwargs = {
            "encoder_type": "ngam",
            "num_archetypes": 16,
            "num_factors": n_factors,
        }

        fitted = ContextualizedBayesianNetworks(**model_kwargs)
        fitted.fit(self.C, self.X, max_epochs=10)

        # Without ``factors=True`` the networks span all x_dim features.
        full_nets = fitted.predict_networks(self.C, individual_preds=False)
        assert np.shape(full_nets) == (self.n_samples, self.x_dim, self.x_dim)

        # With ``factors=True`` the networks are num_factors x num_factors.
        factor_nets = fitted.predict_networks(self.C, factors=True)
        assert np.shape(factor_nets) == (self.n_samples, n_factors, n_factors)

        # A second, unfitted model with the same settings goes through the
        # shared quicktest helper.
        fresh = ContextualizedBayesianNetworks(**model_kwargs)
        self._quicktest(fresh, self.C, self.X, max_epochs=10, learning_rate=1e-3)

    def test_bayesian_default(self):
        """Run the quicktest helper on a model with default arguments."""
        self._quicktest(
            ContextualizedBayesianNetworks(),
            self.C,
            self.X,
            max_epochs=10,
            learning_rate=1e-3,
        )

    def test_bayesian_val_split(self):
        """Run the quicktest helper with ``val_split=0.5``."""
        self._quicktest(
            ContextualizedBayesianNetworks(),
            self.C,
            self.X,
            max_epochs=10,
            learning_rate=1e-3,
            val_split=0.5,
        )

    def test_bayesian_archetypes(self):
        """Run the quicktest helper on a model with ``num_archetypes=16``."""
        self._quicktest(
            ContextualizedBayesianNetworks(num_archetypes=16),
            self.C,
            self.X,
            max_epochs=10,
            learning_rate=1e-3,
        )

    def test_bayesian_encoder(self):
        """Exercise the "ngam" and "mlp" encoders and check output shapes."""
        expected_shape = (self.n_samples, self.x_dim, self.x_dim)
        for encoder in ("ngam", "mlp"):
            candidate = ContextualizedBayesianNetworks(
                encoder_type=encoder, num_archetypes=16
            )
            self._quicktest(
                candidate, self.C, self.X, max_epochs=10, learning_rate=1e-3
            )
            predicted = candidate.predict_networks(self.C, individual_preds=False)
            assert np.shape(predicted) == expected_shape

    def test_bayesian_acyclicity(self):
        """Exercise the "DAGMA" and "poly" DAG losses and check output shapes."""
        expected_shape = (self.n_samples, self.x_dim, self.x_dim)
        for dag_loss in ("DAGMA", "poly"):
            candidate = ContextualizedBayesianNetworks(
                archetype_dag_loss_type=dag_loss, num_archetypes=16
            )
            self._quicktest(
                candidate, self.C, self.X, max_epochs=10, learning_rate=1e-3
            )
            predicted = candidate.predict_networks(self.C, individual_preds=False)
            assert np.shape(predicted) == expected_shape

83 

84 

# Let the module be executed directly (python test_bayesian_networks.py)
# in addition to being collected by a test runner.
if __name__ == "__main__":
    unittest.main()