Coverage for contextualized/easy/tests/test_bayesian_networks.py: 98%
51 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 16:32 -0400
"""
Unit tests for the Easy Contextualized Bayesian Networks model.
"""
5import unittest
6import numpy as np
7import torch
8from contextualized.easy import ContextualizedBayesianNetworks
9from contextualized.easy.tests.test_networks import TestEasyNetworks
class TestContextualizedBayesianNetworks(TestEasyNetworks):
    """
    Test Contextualized Bayesian Network models.

    Shape checks use ``self.assertEqual`` rather than bare ``assert`` so they
    still run when Python is started with ``-O``.
    """

    def setUp(self):
        """
        Shared unit test setup code.

        Builds a random context matrix ``self.C`` of shape (n_samples, c_dim)
        and a random feature matrix ``self.X`` of shape (n_samples, x_dim).
        Both are drawn uniformly from [-0.5, 0.5) and stored as numpy arrays.
        """
        self.n_samples = 100
        self.c_dim = 4
        self.x_dim = 5
        C = torch.rand((self.n_samples, self.c_dim)) - 0.5
        # TODO: Use graph utils to generate X from a network.
        X = torch.rand((self.n_samples, self.x_dim)) - 0.5
        self.C, self.X = C.numpy(), X.numpy()

    def _assert_network_shape(self, networks, size):
        """Assert ``networks`` has shape (n_samples, size, size)."""
        self.assertEqual(np.shape(networks), (self.n_samples, size, size))

    def _quicktest_and_check_networks(self, model):
        """
        Run the inherited ``_quicktest`` on ``model``, then verify that
        population-level network predictions are (n_samples, x_dim, x_dim).
        """
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
        networks = model.predict_networks(self.C, individual_preds=False)
        self._assert_network_shape(networks, self.x_dim)

    def test_bayesian_factors(self):
        """Test factor-based networks in both full and factor-space outputs."""
        num_factors = 2
        model = ContextualizedBayesianNetworks(
            encoder_type="ngam", num_archetypes=16, num_factors=num_factors
        )
        model.fit(self.C, self.X, max_epochs=10)
        networks = model.predict_networks(self.C, individual_preds=False)
        self._assert_network_shape(networks, self.x_dim)
        # With factors=True the predicted networks are expected to be sized
        # by the number of factors rather than by the feature dimension.
        networks = model.predict_networks(self.C, factors=True)
        self._assert_network_shape(networks, num_factors)
        # Also exercise a fresh factor model through the shared quicktest.
        model = ContextualizedBayesianNetworks(
            encoder_type="ngam", num_archetypes=16, num_factors=num_factors
        )
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)

    def test_bayesian_default(self):
        """Test a model constructed with all default arguments."""
        model = ContextualizedBayesianNetworks()
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)

    def test_bayesian_val_split(self):
        """Test fitting with half of the data held out for validation."""
        model = ContextualizedBayesianNetworks()
        self._quicktest(
            model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5
        )

    def test_bayesian_archetypes(self):
        """Test a model with an explicit number of archetypes."""
        model = ContextualizedBayesianNetworks(num_archetypes=16)
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)

    def test_bayesian_encoder(self):
        """Test the "ngam" and "mlp" encoder types."""
        for encoder_type in ("ngam", "mlp"):
            # subTest reports each encoder separately instead of stopping
            # at the first failure.
            with self.subTest(encoder_type=encoder_type):
                model = ContextualizedBayesianNetworks(
                    encoder_type=encoder_type, num_archetypes=16
                )
                self._quicktest_and_check_networks(model)

    def test_bayesian_acyclicity(self):
        """Test the "DAGMA" and "poly" archetype DAG loss types."""
        for dag_loss_type in ("DAGMA", "poly"):
            with self.subTest(archetype_dag_loss_type=dag_loss_type):
                model = ContextualizedBayesianNetworks(
                    archetype_dag_loss_type=dag_loss_type, num_archetypes=16
                )
                self._quicktest_and_check_networks(model)
if __name__ == "__main__":
    # Executing this file directly runs its test cases through unittest's CLI;
    # module="__main__" is unittest's default, spelled out for clarity.
    unittest.main(module="__main__")