Coverage for contextualized/utils.py: 100%
21 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:38 -0400
1"""
2Utility functions, including saving/loading of contextualized models.
3"""
5import torch
def save(model, path):
    """
    Serialize ``model`` to the file at ``path`` with ``torch.save``.

    :param model: Object to persist (for example, a contextualized model).
    :param path: Destination file path. It is opened in binary write mode,
        so an existing file at this location is overwritten.
    """
    with open(path, mode="wb") as handle:
        torch.save(obj=model, f=handle)
def load(path, **kwargs):
    """
    Load an object previously written with ``torch.save`` from ``path``.

    :param path: Source file path, opened in binary read mode.
    :param kwargs: Extra keyword arguments forwarded unchanged to
        ``torch.load``. Examples: ``map_location="cpu"`` loads a model that
        was saved on a GPU onto a CPU-only machine; ``weights_only=False`` is
        needed on PyTorch versions whose default refuses arbitrary pickled
        objects such as whole modules.
    :returns: The deserialized object.
    """
    # SECURITY: torch.load unpickles the file contents, which can execute
    # arbitrary code; only load files from trusted sources.
    with open(path, "rb") as in_file:
        return torch.load(in_file, **kwargs)
class DummyParamPredictor:
    """
    Predicts parameters as all zeros.

    The values of the inputs are ignored; only the length of the first
    positional argument is used, as the batch size.
    """

    def __init__(self, beta_dim, mu_dim):
        """
        :param beta_dim: Per-sample shape of the beta output. Either an int
            (treated as a 1-D shape) or an iterable of ints such as a tuple
            or ``torch.Size``.
        :param mu_dim: Per-sample shape of the mu output, in the same forms
            as ``beta_dim``.
        """
        self.beta_dim = beta_dim
        self.mu_dim = mu_dim

    @staticmethod
    def _as_shape(dim):
        """Return ``dim`` as a shape tuple, treating a bare int as a 1-D shape."""
        return (dim,) if isinstance(dim, int) else tuple(dim)

    def predict_params(self, *args):
        """
        Return zero-filled beta and mu tensors for a batch.

        :param *args: Positional inputs. Only ``len(args[0])`` is used, as the
            batch size ``n``; any further arguments are accepted and ignored.
        :returns: Tuple ``(beta, mu)`` of zero tensors with shapes
            ``(n, *beta_dim)`` and ``(n, *mu_dim)``.
        :raises IndexError: If called with no positional arguments.
        """
        n = len(args[0])
        beta = torch.zeros((n, *self._as_shape(self.beta_dim)))
        mu = torch.zeros((n, *self._as_shape(self.mu_dim)))
        return beta, mu
class DummyYPredictor:
    """
    Predicts Ys as all zeros.

    The values of the inputs are ignored; only the length of the first
    positional argument is used, as the batch size.
    """

    def __init__(self, y_dim):
        """
        :param y_dim: Per-sample shape of the Y output. Either an int
            (treated as a 1-D shape) or an iterable of ints such as a tuple
            or ``torch.Size``.
        """
        self.y_dim = y_dim

    @staticmethod
    def _as_shape(dim):
        """Return ``dim`` as a shape tuple, treating a bare int as a 1-D shape."""
        return (dim,) if isinstance(dim, int) else tuple(dim)

    def predict_y(self, *args):
        """
        Return a zero-filled Y tensor for a batch.

        :param *args: Positional inputs. Only ``len(args[0])`` is used, as the
            batch size ``n``; any further arguments are accepted and ignored.
        :returns: Zero tensor of shape ``(n, *y_dim)``.
        :raises IndexError: If called with no positional arguments.
        """
        n = len(args[0])
        return torch.zeros((n, *self._as_shape(self.y_dim)))