Coverage for contextualized/easy/tests/test_networks.py: 71%
24 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 16:32 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 16:32 -0400
"""
Unit tests for Easy Networks.
"""
5import unittest
6import numpy as np
7import torch
class TestEasyNetworks(unittest.TestCase):
    """
    Test Easy Network models.

    Provides a shared random fixture (``setUp``) and a reusable smoke-test
    helper (``_quicktest``). No ``test_*`` methods are defined here, so
    concrete tests are expected to live in subclasses or be added later.
    """

    # Seed applied to torch's global RNG in setUp so fixtures are
    # reproducible. Subclasses may override it to vary the data.
    seed = 0

    def _quicktest(self, model, C, X, **kwargs):
        """
        Fit ``model`` on (C, X) and check that training reduces the error.

        Steps:
        1. Call ``fit`` with ``max_epochs=0`` and record ``measure_mses`` as
           the baseline.
        2. Call ``fit`` again with ``kwargs`` and record ``measure_mses``.
        3. Assert that ``predict_networks(C)`` has shape
           ``(C.shape[0], X.shape[1], X.shape[1])``.
        4. Assert that the mean error after step 2 is strictly lower than the
           baseline.

        TestCase assertion methods are used instead of bare ``assert`` so the
        checks still run under ``python -O`` and failures report the values.

        Args:
            model: object exposing ``fit``, ``measure_mses`` and
                ``predict_networks`` (presumably an Easy Network model).
            C: context array; rows are samples.
            X: data array; ``X.shape[1]`` is the network dimension.
            **kwargs: forwarded to the second ``model.fit`` call.

        Raises:
            AssertionError: if the shape check or the error check fails.
        """
        print(f"{type(model)} quicktest")
        # A zero-epoch fit presumably sets the model up without training it,
        # giving a baseline error to compare the trained model against.
        model.fit(C, X, max_epochs=0)
        err_init = model.measure_mses(C, X)
        model.fit(C, X, **kwargs)
        err_trained = model.measure_mses(C, X)
        W_pred = model.predict_networks(C)
        self.assertEqual(
            tuple(W_pred.shape), (C.shape[0], X.shape[1], X.shape[1])
        )
        self.assertLess(np.mean(err_trained), np.mean(err_init))

    def setUp(self):
        """
        Shared unit test setup code.
        For subclasses, override this method to set up the test data.

        Sets ``self.C`` with shape (n_samples, c_dim) and ``self.X`` with
        shape (n_samples, x_dim). Both are NumPy arrays with entries drawn
        uniformly from [-0.5, 0.5).
        """
        self.n_samples = 100
        self.c_dim = 4
        self.x_dim = 5
        # Seed torch so the fixture is reproducible, and so is any
        # torch-based model initialization that follows in the same test.
        # Without this, the "error decreased" check may be flaky.
        torch.manual_seed(self.seed)
        C = torch.rand((self.n_samples, self.c_dim)) - 0.5
        X = torch.rand((self.n_samples, self.x_dim)) - 0.5
        self.C, self.X = C.numpy(), X.numpy()
if __name__ == "__main__":
    # Run this module's TestCase classes when the file is executed directly
    # (module="__main__" is unittest.main's default, spelled out here).
    unittest.main(module="__main__")