Coverage for contextualized/regression/datasets.py: 93%
86 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
1"""
2Data generators used for Contextualized regression training.
3"""
5from abc import abstractmethod
6import torch
7from torch.utils.data import IterableDataset
class Dataset:
    """Superclass for datastreams (iterators) used to train contextualized.regression models.

    The stream is its own iterator: ``__iter__`` rewinds the cursors and
    returns ``self``, so two simultaneous iterations over one instance share
    cursor state.

    Args:
        C: Context values; anything ``torch.tensor`` accepts (array, tensor,
            nested sequence). The first axis is taken as the sample count.
        X: Predictor values, converted the same way as ``C``.
        Y: Outcome values, converted the same way as ``C``.
        dtype: Torch dtype used for the stored ``C``, ``X`` and ``Y`` tensors.
    """

    def __init__(self, C, X, Y, dtype=torch.float):
        self.C = torch.tensor(C, dtype=dtype)
        self.X = torch.tensor(X, dtype=dtype)
        self.Y = torch.tensor(Y, dtype=dtype)
        # Cursors over samples (n_i), predictors (x_i) and outcomes (y_i);
        # each subclass decides which of them it advances.
        self.n_i = 0
        self.x_i = 0
        self.y_i = 0
        # Read sizes from the converted tensors rather than the raw inputs so
        # plain Python sequences (which have no ``.shape``) are accepted too.
        self.n = self.C.shape[0]
        self.c_dim = self.C.shape[-1]
        self.x_dim = self.X.shape[-1]
        self.y_dim = self.Y.shape[-1]
        self.dtype = dtype

    def __iter__(self):
        """Rewind all cursors and return ``self`` as the iterator."""
        self.n_i = 0
        self.x_i = 0
        self.y_i = 0
        return self

    @abstractmethod
    def __next__(self):
        """Return the next item of the stream; subclasses must implement this."""
        # This class is not an ABC, so @abstractmethod alone is not enforced;
        # raising stops a bare instance from yielding None forever.
        raise NotImplementedError(f"{type(self).__name__} must implement __next__")

    @abstractmethod
    def __len__(self):
        """Return the number of items per pass; subclasses must implement this."""
        raise NotImplementedError(f"{type(self).__name__} must implement __len__")
class MultivariateDataset(Dataset):
    """
    Multivariate stream: one item per sample, covering all outcomes at once.

    Each item is ``(context, predictors, outcomes, index)``. The sample's
    predictor row is repeated once per outcome via ``expand(y_dim, -1)``, and
    the outcome row gets a trailing singleton axis via ``unsqueeze(-1)``.
    """

    def __next__(self):
        idx = self.n_i
        if idx < self.n:
            context = self.C[idx]
            predictors = self.X[idx].expand(self.y_dim, -1)
            outcomes = self.Y[idx].unsqueeze(-1)
            sample = (context, predictors, outcomes, idx)
            self.n_i = idx + 1
            return sample
        # Exhausted: rewind so a later next() starts over, then stop.
        self.n_i = 0
        raise StopIteration

    def __len__(self):
        """Number of items per pass: one per sample."""
        return self.n
class UnivariateDataset(Dataset):
    """
    Univariate stream: one item per sample, pairing every predictor with every outcome.

    Each item is ``(context, predictors, outcomes, index)``. For 2-D ``X`` and
    ``Y``, both ``predictors`` and ``outcomes`` come out laid out as
    ``(y_dim, x_dim, 1)``: the predictor row is tiled across outcomes, and the
    outcome row is tiled across predictors and then transposed.
    """

    def __next__(self):
        idx = self.n_i
        if idx < self.n:
            context = self.C[idx]
            tiled_x = self.X[idx].expand(self.y_dim, -1)
            tiled_y = self.Y[idx].expand(self.x_dim, -1).T
            sample = (context, tiled_x.unsqueeze(-1), tiled_y.unsqueeze(-1), idx)
            self.n_i = idx + 1
            return sample
        # Exhausted: rewind so a later next() starts over, then stop.
        self.n_i = 0
        raise StopIteration

    def __len__(self):
        """Number of items per pass: one per sample."""
        return self.n
class MultitaskMultivariateDataset(Dataset):
    """
    Multi-task Multivariate Dataset.

    Emits one item per (sample, outcome) pair, with the outcome index varying
    fastest. Each item is
    ``(C[n_i], task, X[n_i], Y[n_i, y_i].unsqueeze(0), n_i, y_i)``, where
    ``task`` is a length-``y_dim`` one-hot vector marking outcome ``y_i``.
    """

    def __next__(self):
        # All outcomes of the current sample emitted: move to the next sample.
        if self.y_i >= self.y_dim:
            self.n_i += 1
            self.y_i = 0
        if self.n_i >= self.n:
            self.n_i = 0
            raise StopIteration
        # Build the task vector in the dataset dtype so it matches C/X/Y;
        # torch.zeros would otherwise use the global default dtype.
        task = torch.zeros(self.y_dim, dtype=self.dtype)
        task[self.y_i] = 1
        ret = (
            self.C[self.n_i],
            task,
            self.X[self.n_i],
            self.Y[self.n_i, self.y_i].unsqueeze(0),
            self.n_i,
            self.y_i,
        )
        self.y_i += 1
        return ret

    def __len__(self):
        """Number of items per pass: one per (sample, outcome) pair."""
        return self.n * self.y_dim
class MultitaskUnivariateDataset(Dataset):
    """
    Multitask Univariate Dataset.

    Emits one item per (sample, predictor, outcome) triple. The outcome index
    varies fastest, then the predictor, then the sample. Each item is
    ``(C[n_i], task, X[n_i, x_i].unsqueeze(0), Y[n_i, y_i].unsqueeze(0),
    n_i, x_i, y_i)``, where ``task`` has length ``x_dim + y_dim`` with a 1 at
    position ``x_i`` and another at ``x_dim + y_i``.
    """

    def __next__(self):
        # Advance the cursors like nested loops: y_i innermost, then x_i, then n_i.
        if self.y_i >= self.y_dim:
            self.x_i += 1
            self.y_i = 0
        if self.x_i >= self.x_dim:
            self.n_i += 1
            self.x_i = 0
        if self.n_i >= self.n:
            self.n_i = 0
            raise StopIteration
        # Build the task vector in the dataset dtype so it matches C/X/Y;
        # torch.zeros would otherwise use the global default dtype.
        task = torch.zeros(self.x_dim + self.y_dim, dtype=self.dtype)
        task[self.x_i] = 1
        task[self.x_dim + self.y_i] = 1
        ret = (
            self.C[self.n_i],
            task,
            self.X[self.n_i, self.x_i].unsqueeze(0),
            self.Y[self.n_i, self.y_i].unsqueeze(0),
            self.n_i,
            self.x_i,
            self.y_i,
        )
        self.y_i += 1
        return ret

    def __len__(self):
        """Number of items per pass: one per (sample, predictor, outcome) triple."""
        return self.n * self.x_dim * self.y_dim
class DataIterable(IterableDataset):
    """
    PyTorch ``IterableDataset`` adapter around a stream object.

    Every ``iter()`` call is forwarded to the wrapped stream, so the stream's
    own ``__iter__`` decides what iterator is handed to the caller.
    """

    def __init__(self, dataset):
        # The wrapped stream object.
        self.dataset = dataset

    def __iter__(self):
        stream = iter(self.dataset)
        return stream