Coverage for contextualized/easy/tests/test_networks.py: 71%

24 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-18 16:32 -0400

1"""" 

2Unit tests for Easy Networks. 

3""" 

4 

5import unittest 

6import numpy as np 

7import torch 

8 

9 

class TestEasyNetworks(unittest.TestCase):
    """
    Test Easy Network models.

    Provides shared random fixtures (``self.C``, ``self.X``) built in
    ``setUp`` and a ``_quicktest`` helper that fits a model and checks its
    predicted network shape and that training reduced its error.
    """

    # Seed for the fixture generator; subclasses may override it. A local
    # torch.Generator is used so the global torch RNG state is not touched.
    seed = 0

    def _quicktest(self, model, C, X, **kwargs):
        """
        Fit ``model`` and sanity-check its predictions and training effect.

        Args:
            model: Object exposing ``fit(C, X, **kw)``, ``measure_mses(C, X)``
                and ``predict_networks(C)``.
            C: Context array; ``C.shape[0]`` is the number of samples.
            X: Data array; ``X.shape[1]`` is the network dimension.
            **kwargs: Extra keyword arguments forwarded to the training
                ``model.fit`` call (e.g. ``max_epochs``).

        Raises:
            AssertionError: If ``predict_networks`` does not return shape
                ``(C.shape[0], X.shape[1], X.shape[1])`` or the mean MSE after
                training is not strictly below the mean MSE before training.
        """
        print(f"{type(model)} quicktest")
        # Presumably a zero-epoch fit sets the model up without training it,
        # giving a baseline error to compare against — confirm in the model API.
        model.fit(C, X, max_epochs=0)
        err_init = model.measure_mses(C, X)
        model.fit(C, X, **kwargs)
        err_trained = model.measure_mses(C, X)
        W_pred = model.predict_networks(C)
        n_samples, x_dim = C.shape[0], X.shape[1]
        # unittest assertions survive `python -O` and report both operands.
        self.assertEqual(tuple(W_pred.shape), (n_samples, x_dim, x_dim))
        self.assertLess(np.mean(err_trained), np.mean(err_init))

    def setUp(self):
        """
        Shared unit test setup code.
        For subclasses, override this method to set up the test data.

        Builds ``self.C`` with shape (n_samples, c_dim) and ``self.X`` with
        shape (n_samples, x_dim) as NumPy arrays drawn uniformly from
        [-0.5, 0.5). A generator seeded from ``self.seed`` makes the fixtures
        reproducible across runs.
        """
        self.n_samples = 100
        self.c_dim = 4
        self.x_dim = 5
        gen = torch.Generator().manual_seed(self.seed)
        C = torch.rand((self.n_samples, self.c_dim), generator=gen) - 0.5
        X = torch.rand((self.n_samples, self.x_dim), generator=gen) - 0.5
        self.C, self.X = C.numpy(), X.numpy()

40 

if __name__ == "__main__":
    # When executed directly as a script, load and run the TestCase classes
    # defined in this module via the standard unittest command-line runner.
    unittest.main(module="__main__")