Coverage for contextualized/utils.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:38 -0400

1""" 

2Utility functions, including saving/loading of contextualized models. 

3""" 

4 

5import torch 

6 

7 

def save(model, path):
    """
    Serialize ``model`` to the file at ``path`` using ``torch.save``.

    The destination is opened in binary write mode, so an existing file at
    ``path`` is overwritten.

    :param model: object to serialize.
    :param path: location of the file to write.

    """
    with open(path, mode="wb") as handle:
        torch.save(model, handle)

17 

18 

def load(path):
    """
    Loads a model from path.

    The file is opened in binary read mode and deserialized with
    ``torch.load``. Full unpickling is requested explicitly
    (``weights_only=False``) so that arbitrary pickled objects can be
    restored; PyTorch 2.6 changed the default to ``weights_only=True``,
    which rejects objects outside its allow-list.

    :param path: location of the file to read.
    :return: the deserialized object.

    """
    with open(path, "rb") as in_file:
        # NOTE(security): weights_only=False runs the pickle machinery and can
        # execute arbitrary code embedded in the file. Only load files that
        # come from a trusted source.
        model = torch.load(in_file, weights_only=False)
    return model

28 

29 

class DummyParamPredictor:
    """
    Stand-in predictor whose parameter outputs are always zero tensors.
    """

    def __init__(self, beta_dim, mu_dim):
        # Both dims are unpacked into tensor shapes later, so they must be
        # iterables of sizes (e.g. tuples).
        self.beta_dim = beta_dim
        self.mu_dim = mu_dim

    def predict_params(self, *args):
        """
        Return zero-filled tensors for both parameter sets.

        Only the length of the first positional argument is used; it sets
        the leading dimension of both outputs. Any further arguments are
        ignored.

        :param *args: positional inputs; ``args[0]`` must support ``len``.
        :return: tuple ``(betas, mus)`` of zero tensors with shapes
            ``(n, *beta_dim)`` and ``(n, *mu_dim)``, where
            ``n = len(args[0])``.

        """
        leading = (len(args[0]),)
        beta_shape = leading + tuple(self.beta_dim)
        mu_shape = leading + tuple(self.mu_dim)
        betas = torch.zeros(beta_shape)
        mus = torch.zeros(mu_shape)
        return betas, mus

47 

48 

class DummyYPredictor:
    """
    Stand-in predictor whose Y outputs are always a zero tensor.
    """

    def __init__(self, y_dim):
        # Unpacked into the output shape later, so this must be an iterable
        # of sizes (e.g. a tuple).
        self.y_dim = y_dim

    def predict_y(self, *args):
        """
        Return a zero-filled tensor of predictions.

        Only the length of the first positional argument is used; it sets
        the leading dimension of the output. Any further arguments are
        ignored.

        :param *args: positional inputs; ``args[0]`` must support ``len``.
        :return: zero tensor of shape ``(n, *y_dim)``, where
            ``n = len(args[0])``.

        """
        out_shape = (len(args[0]),) + tuple(self.y_dim)
        return torch.zeros(out_shape)