Coverage for contextualized/dags/tests_fast.py: 98%

41 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-18 16:32 -0400

1""" 

2Unit tests for DAG models. 

3""" 

4import unittest 

5import numpy as np 

6from pytorch_lightning import seed_everything 

7from pytorch_lightning.callbacks import LearningRateFinder 

8 

9 

10from contextualized.dags.lightning_modules import NOTMAD 

11from contextualized.dags import graph_utils 

12from contextualized.dags.trainers import GraphTrainer 

13from contextualized.dags.losses import mse_loss as mse 

14 

15 

class TestNOTMAD(unittest.TestCase):
    """Unit tests for the NOTMAD context-specific DAG model.

    Each test trains a tiny NOTMAD instance for one epoch on random data
    and checks that training reduces the DAG-reconstruction MSE.
    """

    # NOTE: the redundant __init__ that only forwarded to super() was removed;
    # unittest.TestCase's default constructor is identical.

    def setUp(self):
        """Seed all RNGs and build a small random (context, feature) dataset."""
        seed_everything(0)  # deterministic numpy/torch/python seeds per test
        self.n = 10      # number of samples
        self.c_dim = 4   # context dimensionality
        self.x_dim = 3   # feature dimensionality
        self.C = np.random.uniform(-1, 1, size=(self.n, self.c_dim))
        self.X = np.random.uniform(-1, 1, size=(self.n, self.x_dim))

    def _train(self, model_args, n_epochs):
        """Build, train, and sanity-check a NOTMAD model.

        Args:
            model_args (dict): optional overrides for "dag", "num_factors",
                and "num_archetypes" archetype parameters.
            n_epochs (int): number of training epochs.

        Asserts that predicted parameter shapes are (n, x_dim, x_dim) and
        that training strictly reduces reconstruction MSE.
        """
        k = 6  # number of archetypes
        # Random archetype initializations with zeroed diagonals (no self-loops).
        # NOTE(review): shape is (k, 4, 4) while x_dim is 3 — confirm the
        # expected archetype dimensionality against NOTMAD's contract.
        init_mat = np.random.uniform(-0.1, 0.1, size=(k, 4, 4)) * np.tile(
            1 - np.eye(4), (k, 1, 1)
        )
        model = NOTMAD(
            self.C.shape[-1],
            self.X.shape[-1],
            archetype_params={
                "l1": 0.0,
                # Default to a NOTEARS acyclicity loss unless the caller
                # supplies its own "dag" configuration.
                "dag": model_args.get(
                    "dag",
                    {
                        "loss_type": "NOTEARS",
                        "params": {
                            "alpha": 1e-1,
                            "rho": 1e-2,
                            "h_old": 0.0,
                            "tol": 0.25,
                            "use_dynamic_alpha_rho": True,
                        },
                    },
                ),
                "init_mat": init_mat,
                "num_factors": model_args.get("num_factors", 0),
                "factor_mat_l1": 0.0,
                "num_archetypes": model_args.get("num_archetypes", k),
            },
        )
        dataloader = model.dataloader(self.C, self.X, batch_size=1, num_workers=0)
        trainer = GraphTrainer(
            max_epochs=n_epochs, deterministic=True, enable_progress_bar=False
        )
        # Separate single-device trainer for prediction to keep outputs ordered.
        predict_trainer = GraphTrainer(
            deterministic=True, enable_progress_bar=False, devices=1
        )
        init_preds = predict_trainer.predict_params(
            model, dataloader, project_to_dag=False
        )
        # unittest assertions (not bare `assert`) so failures report values
        # and the checks survive `python -O`.
        self.assertEqual(init_preds.shape, (self.n, self.x_dim, self.x_dim))
        init_mse = mse(graph_utils.dag_pred_np(self.X, init_preds), self.X)
        trainer.fit(model, dataloader)
        final_preds = predict_trainer.predict_params(
            model, dataloader, project_to_dag=False
        )
        self.assertEqual(final_preds.shape, (self.n, self.x_dim, self.x_dim))
        final_mse = mse(graph_utils.dag_pred_np(self.X, final_preds), self.X)
        # Training for even one epoch should improve on the random init.
        self.assertLess(final_mse, init_mse)

    def test_notmad_dagma(self):
        """NOTMAD trains with the DAGMA acyclicity loss."""
        self._train(
            {
                "dag": {
                    "loss_type": "DAGMA",
                    "params": {
                        "alpha": 1.0,
                    },
                }
            },
            1,
        )

    def test_notmad_notears(self):
        """NOTMAD trains with the default NOTEARS acyclicity loss."""
        self._train({}, 1)

    def test_notmad_factor_graphs(self):
        """Unit tests for the factor-graph feature of NOTMAD."""
        self._train({"num_factors": 3}, 1)

102 

103 

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()