Coverage for contextualized/easy/tests/test_markov_networks.py: 96%
27 statements
Report generated by coverage.py v7.4.4 on 2024-04-18 16:32 -0400.
"""
Unit tests for Easy Networks: ContextualizedMarkovNetworks.
"""
5import unittest
6import numpy as np
7import torch
8from contextualized.easy.tests.test_networks import TestEasyNetworks
9from contextualized.easy import ContextualizedMarkovNetworks
class TestContextualizedMarkovNetworks(TestEasyNetworks):
    """
    Test Contextualized Markov Network models.
    """

    def setUp(self):
        """
        Shared unit test setup code.

        Creates ``self.n_samples`` rows of context ``self.C`` (``self.c_dim``
        columns) and observations ``self.X`` (``self.x_dim`` columns) as NumPy
        arrays, with every entry drawn uniformly from [-0.5, 0.5).
        """
        self.n_samples = 100
        self.c_dim = 4
        self.x_dim = 5
        C = torch.rand((self.n_samples, self.c_dim)) - 0.5
        # TODO: Use graph utils to generate X from a network.
        X = torch.rand((self.n_samples, self.x_dim)) - 0.5
        self.C, self.X = C.numpy(), X.numpy()

    def test_markov(self):
        """Test Case for ContextualizedMarkovNetworks.

        Exercises the default model (with and without a validation split),
        a model with 16 archetypes, and an NGAM-encoder model with 16
        archetypes, then checks the shapes returned by
        ``predict_precisions``.
        """
        # NOTE(review): `_quicktest` is inherited from TestEasyNetworks and is
        # not visible here; presumably it fits the model and sanity-checks
        # its predictions -- confirm in test_networks.py.
        model = ContextualizedMarkovNetworks()
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)
        # Re-run on the same model instance, this time holding out half the
        # data for validation.
        self._quicktest(
            model, self.C, self.X, max_epochs=10, learning_rate=1e-3, val_split=0.5
        )

        model = ContextualizedMarkovNetworks(num_archetypes=16)
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)

        model = ContextualizedMarkovNetworks(encoder_type="ngam", num_archetypes=16)
        self._quicktest(model, self.C, self.X, max_epochs=10, learning_rate=1e-3)

        # Aggregated predictions: one (x_dim, x_dim) precision matrix per
        # sample. assertEqual (unlike a bare assert) survives `python -O` and
        # reports both shapes on failure.
        omegas = model.predict_precisions(self.C, individual_preds=False)
        self.assertEqual(
            np.shape(omegas), (self.n_samples, self.x_dim, self.x_dim)
        )

        # Individual predictions add a leading axis; this test expects it to
        # have size 1 for the model configured above.
        omegas = model.predict_precisions(self.C, individual_preds=True)
        self.assertEqual(
            np.shape(omegas), (1, self.n_samples, self.x_dim, self.x_dim)
        )
if __name__ == "__main__":
    # Runs only when this file is executed as a script, not when a test
    # runner such as pytest imports it. unittest.main() collects the
    # TestCase subclasses defined in this module and runs them.
    unittest.main()