Coverage for contextualized/regression/datasets.py: 93%

86 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2Data generators used for Contextualized regression training. 

3""" 

4 

5from abc import abstractmethod 

6import torch 

7from torch.utils.data import IterableDataset 

8 

9 

class Dataset:
    """Superclass for datastreams (iterators) used to train contextualized.regression models.

    Converts context ``C``, predictors ``X`` and outcomes ``Y`` to tensors of
    ``dtype`` and holds the cursor state (``n_i``, ``x_i``, ``y_i``) that
    subclasses advance in ``__next__``. The instance is its own iterator
    (``__iter__`` rewinds the cursors and returns ``self``), so two
    simultaneous iterations over one instance share the same cursors.

    Args:
        C: context, array-like or tensor; its first axis indexes samples and
            its last axis is the context dimension.
        X: predictors, array-like or tensor; last axis is the predictor dim.
        Y: outcomes, array-like or tensor; last axis is the outcome dim.
        dtype: torch dtype every input is converted to (default ``torch.float``).
    """

    def __init__(self, C, X, Y, dtype=torch.float):
        self.C = self._to_tensor(C, dtype)
        self.X = self._to_tensor(X, dtype)
        self.Y = self._to_tensor(Y, dtype)
        self.n_i = 0
        self.x_i = 0
        self.y_i = 0
        # Sizes come from the converted tensors so that plain (nested) lists,
        # which have no ``.shape``, are accepted as well as arrays/tensors.
        self.n = self.C.shape[0]
        self.c_dim = self.C.shape[-1]
        self.x_dim = self.X.shape[-1]
        self.y_dim = self.Y.shape[-1]
        self.dtype = dtype

    @staticmethod
    def _to_tensor(data, dtype):
        """Return a fresh, detached tensor copy of ``data`` with ``dtype``."""
        if isinstance(data, torch.Tensor):
            # Equivalent to torch.tensor(data) but without the UserWarning
            # PyTorch raises for copy-constructing from a tensor.
            return data.detach().clone().to(dtype=dtype)
        return torch.tensor(data, dtype=dtype)

    def __iter__(self):
        """Rewind all cursors and return ``self`` as the iterator."""
        self.n_i = 0
        self.x_i = 0
        self.y_i = 0
        return self

    @abstractmethod
    def __next__(self):
        """Return the next item; must be implemented by subclasses."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement __next__"
        )

    @abstractmethod
    def __len__(self):
        """Return the number of items per pass; must be implemented by subclasses."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement __len__"
        )

39 

40 

class MultivariateDataset(Dataset):
    """
    Multivariate dataset: one item per sample.

    Each item is ``(context, predictors, outcomes, index)``:
    ``context`` is ``C[index]``; ``predictors`` is ``X[index]`` broadcast
    (via ``expand``, no copy) to one row per outcome, i.e. ``(y_dim, x_dim)``
    for 2-D ``X``; ``outcomes`` is ``Y[index]`` with a trailing unit axis,
    i.e. ``(y_dim, 1)`` for 2-D ``Y``.
    """

    def __next__(self):
        idx = self.n_i
        if idx >= self.n:
            # Pass finished: rewind so a further next() starts over.
            self.n_i = 0
            raise StopIteration
        context = self.C[idx]
        predictors = self.X[idx].expand(self.y_dim, -1)
        outcomes = self.Y[idx].unsqueeze(-1)
        item = (context, predictors, outcomes, idx)
        self.n_i = idx + 1
        return item

    def __len__(self):
        # One item per sample.
        return self.n

61 

62 

class UnivariateDataset(Dataset):
    """
    Univariate dataset: one item per sample, laid out as a grid of
    (outcome, predictor) pairs.

    Each item is ``(context, x_grid, y_grid, index)``. For 2-D ``X`` and
    ``Y``, both grids have shape ``(y_dim, x_dim, 1)``: ``x_grid[j, k]`` is
    predictor ``k`` and ``y_grid[j, k]`` is outcome ``j`` of the sample.
    """

    def __next__(self):
        i = self.n_i
        if i >= self.n:
            # Pass finished: rewind so a further next() starts over.
            self.n_i = 0
            raise StopIteration
        context = self.C[i]
        # Predictor row repeated once per outcome, plus a trailing unit axis.
        x_grid = self.X[i].expand(self.y_dim, -1).unsqueeze(-1)
        # Outcome row repeated once per predictor, transposed so outcomes
        # index the first axis, plus a trailing unit axis.
        y_grid = self.Y[i].expand(self.x_dim, -1).T.unsqueeze(-1)
        item = (context, x_grid, y_grid, i)
        self.n_i = i + 1
        return item

    def __len__(self):
        # One item per sample.
        return self.n

83 

84 

class MultitaskMultivariateDataset(Dataset):
    """
    Multi-task Multivariate Dataset.

    Yields one item per (sample, outcome) pair, with the outcome index varying
    fastest: ``(C[n_i], t, X[n_i], Y[n_i, y_i].unsqueeze(0), n_i, y_i)``,
    where ``t`` is a length-``y_dim`` one-hot task vector marking outcome
    ``y_i``.
    """

    def __next__(self):
        if self.y_i >= self.y_dim:
            # Every outcome of the current sample was emitted: next sample.
            self.n_i += 1
            self.y_i = 0
        # Checked on every call, not only after a rollover, so an empty
        # dataset (n == 0) or one without outcomes (y_dim == 0) stops cleanly
        # instead of indexing past the end and raising IndexError.
        if self.n_i >= self.n or self.y_dim == 0:
            self.n_i = 0
            raise StopIteration
        # One-hot task indicator, in the same dtype as C, X and Y.
        t = torch.zeros(self.y_dim, dtype=self.dtype)
        t[self.y_i] = 1
        ret = (
            self.C[self.n_i],
            t,
            self.X[self.n_i],
            self.Y[self.n_i, self.y_i].unsqueeze(0),
            self.n_i,
            self.y_i,
        )
        self.y_i += 1
        return ret

    def __len__(self):
        # One item per (sample, outcome) pair.
        return self.n * self.y_dim

112 

113 

class MultitaskUnivariateDataset(Dataset):
    """
    Multitask Univariate Dataset

    Yields one item per (sample, predictor, outcome) triple, with the outcome
    index varying fastest, then the predictor, then the sample:
    ``(C[n_i], t, X[n_i, x_i].unsqueeze(0), Y[n_i, y_i].unsqueeze(0),
    n_i, x_i, y_i)``. ``t`` has length ``x_dim + y_dim``: its first
    ``x_dim`` entries one-hot encode ``x_i`` and its last ``y_dim`` entries
    one-hot encode ``y_i``.
    """

    def __next__(self):
        if self.y_i >= self.y_dim:
            # All outcomes for this predictor emitted: next predictor.
            self.x_i += 1
            self.y_i = 0
            if self.x_i >= self.x_dim:
                # All predictors for this sample emitted: next sample.
                self.n_i += 1
                self.x_i = 0
        # Checked on every call, not only after a full rollover, so an empty
        # dataset (n == 0) or a zero predictor/outcome dimension stops
        # cleanly instead of raising IndexError.
        if self.n_i >= self.n or self.x_dim == 0 or self.y_dim == 0:
            self.n_i = 0
            raise StopIteration
        # Two-hot (predictor, outcome) task indicator, in the dataset dtype.
        t = torch.zeros(self.x_dim + self.y_dim, dtype=self.dtype)
        t[self.x_i] = 1
        t[self.x_dim + self.y_i] = 1
        ret = (
            self.C[self.n_i],
            t,
            self.X[self.n_i, self.x_i].unsqueeze(0),
            self.Y[self.n_i, self.y_i].unsqueeze(0),
            self.n_i,
            self.x_i,
            self.y_i,
        )
        self.y_i += 1
        return ret

    def __len__(self):
        # One item per (sample, predictor, outcome) triple.
        return self.n * self.x_dim * self.y_dim

146 

147 

class DataIterable(IterableDataset):
    """Expose a datastream object to PyTorch as an ``IterableDataset``.

    Iteration is delegated to the wrapped object with ``iter()``.
    NOTE(review): PyTorch hands each DataLoader worker its own replica of an
    IterableDataset, so with ``num_workers > 0`` the wrapped stream would
    likely be replayed once per worker — confirm before enabling workers.
    """

    def __init__(self, dataset):
        # Only a reference is kept; nothing is copied or validated here.
        self.dataset = dataset

    def __iter__(self):
        wrapped = self.dataset
        return iter(wrapped)