Coverage for contextualized/modules.py: 100%

65 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2PyTorch modules which are used as building blocks of Contextualized models. 

3""" 

4 

5import torch 

6from torch import nn 

7 

8from contextualized.functions import identity_link 

9 

10 

class SoftSelect(nn.Module):
    """
    Parameter sharing for multiple context encoders:
    Batched computation for mapping many subtypes onto d-dimensional archetypes

    The learnable ``archetypes`` parameter is stored with shape
    ``(*out_shape, *in_dims)``. ``forward`` takes one weight tensor per entry
    of ``in_dims`` and contracts each one against the matching trailing
    archetype dimension, yielding ``(batch_size, *out_shape)``.

    Args:
        in_dims: Iterable of sizes of the dimensions mixed away by the weights.
        out_shape: Iterable giving the shape of each selected output.
    """

    def __init__(self, in_dims, out_shape):
        super().__init__()
        self.in_dims = in_dims
        self.out_shape = out_shape
        # Small uniform initialization in [-1e-2, 1e-2).
        init_mat = torch.rand(list(out_shape) + list(in_dims)) * 2e-2 - 1e-2
        self.archetypes = nn.parameter.Parameter(init_mat, requires_grad=True)

    def forward(self, *batch_weights):
        """Torch Forward pass.

        Args:
            *batch_weights: One weight tensor per entry of ``in_dims``, in the
                same order; each is expected to be ``(batch_size, in_dims[j])``.

        Returns:
            The archetypes mixed by the weights; ``(batch_size, *out_shape)``
            when one weight tensor is supplied for every entry of ``in_dims``.
        """
        batch_size = batch_weights[0].shape[0]
        expand_dims = [batch_size] + [-1] * self.archetypes.dim()
        batch_archetypes = self.archetypes.unsqueeze(0).expand(expand_dims)
        # The trailing archetype dim is contracted first, so walk the weights
        # from last to first.
        for batch_w in reversed(batch_weights):
            batch_w = batch_w.unsqueeze(-1)
            # Insert singleton dims right after the batch dim so matmul
            # broadcasts the weight over the leading archetype dimensions.
            for _ in range(batch_archetypes.dim() - batch_w.dim()):
                batch_w = batch_w.unsqueeze(1)
            batch_archetypes = torch.matmul(batch_archetypes, batch_w).squeeze(-1)
        return batch_archetypes

    def _cycle_dims(self, tensor, n_steps):
        """
        Rotate tensor dimensions back to front for ``n_steps`` steps.

        Each step moves the LAST dimension to the front, e.g. shape
        ``(a, b, c)`` becomes ``(c, a, b)`` after one step and ``(b, c, a)``
        after two. Steps wrap around modulo the number of dimensions; a
        non-positive ``n_steps`` (or a 0-d tensor) returns ``tensor`` as is.
        The result is a view of ``tensor`` (no copy).
        """
        ndim = tensor.dim()
        if ndim == 0 or n_steps <= 0:
            return tensor
        n_steps %= ndim
        if n_steps == 0:
            return tensor
        # The last n_steps dims go first, followed by the remaining leading dims.
        order = list(range(ndim - n_steps, ndim)) + list(range(ndim - n_steps))
        return tensor.permute(order)

    def get_archetypes(self):
        """
        Returns archetype parameters: (*in_dims, *out_shape)

        The result is a view of the stored ``(*out_shape, *in_dims)``
        parameter with its dimensions rotated.
        """
        return self._cycle_dims(self.archetypes, len(self.in_dims))

    def set_archetypes(self, archetypes):
        """
        Sets archetype parameters

        Requires archetypes.shape == (*in_dims, *out_shape). The tensor is
        rotated to the internal ``(*out_shape, *in_dims)`` layout and wrapped in
        a new trainable ``nn.Parameter`` that replaces the old one.

        Raises:
            ValueError: If ``archetypes`` does not have shape
                ``(*in_dims, *out_shape)``.
        """
        expected = tuple(self.in_dims) + tuple(self.out_shape)
        actual = tuple(archetypes.shape)
        if actual != expected:
            # Catch a wrongly oriented tensor (e.g. one already in the internal
            # storage layout) instead of silently storing scrambled archetypes.
            raise ValueError(
                f"Expected archetypes of shape (*in_dims, *out_shape) = "
                f"{expected}, got {actual}"
            )
        self.archetypes = nn.parameter.Parameter(
            self._cycle_dims(archetypes, len(self.out_shape)), requires_grad=True
        )

60 

61 

class Explainer(SoftSelect):
    """
    2D subtype-archetype parameter sharing

    A ``SoftSelect`` with a single mixing dimension of size ``k``: its
    archetypes are stored as ``(*out_shape, k)``, and it is intended to be
    called with a single weight tensor.
    """

    def __init__(self, k, out_shape):
        in_dims = (k,)
        super().__init__(in_dims, out_shape)

69 

70 

class MLP(nn.Module):
    """
    Multi-layer perceptron

    Stacks ``layers`` hidden ``Linear -> activation`` pairs of size ``width``
    followed by a final ``Linear`` to ``output_dim``; ``link_fn`` is applied to
    the result. With ``layers <= 0`` the network is a single linear map from
    ``input_dim`` to ``output_dim`` (``width`` and ``activation`` are unused).
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        width,
        layers,
        activation=nn.ReLU,
        link_fn=identity_link,
    ):
        super().__init__()
        if layers <= 0:
            # Linear encoder: no hidden layers at all.
            modules = [nn.Linear(input_dim, output_dim)]
        else:
            modules = []
            fan_in = input_dim
            for _ in range(layers):
                # A fresh activation instance per hidden layer.
                modules.append(nn.Linear(fan_in, width))
                modules.append(activation())
                fan_in = width
            modules.append(nn.Linear(width, output_dim))
        self.mlp = nn.Sequential(*modules)
        self.link_fn = link_fn

    def forward(self, X):
        """Torch Forward pass: run the layer stack, then apply ``link_fn``."""
        return self.link_fn(self.mlp(X))

100 

101 

class NGAM(nn.Module):
    """
    Neural generalized additive model

    Every input feature gets its own single-input ``MLP``. Sub-network ``i``
    sees only feature column ``i``; the per-feature outputs are summed and the
    sum is passed through ``link_fn``.

    Args:
        input_dim: Number of input features (one sub-network per feature).
        output_dim: Output size of each sub-network and of the model.
        width: Hidden width of each sub-network.
        layers: Number of hidden layers of each sub-network.
        activation: Activation module class used inside the sub-networks.
        link_fn: Function applied to the summed per-feature outputs.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        width,
        layers,
        activation=nn.ReLU,
        link_fn=identity_link,
    ):
        super().__init__()
        self.input_dim = input_dim
        # Backward-compatible alias for the historically misspelled attribute.
        self.intput_dim = input_dim
        self.output_dim = output_dim
        self.nams = nn.ModuleList(
            [
                MLP(
                    1,
                    output_dim,
                    width,
                    layers,
                    activation=activation,
                    link_fn=identity_link,
                )
                for _ in range(input_dim)
            ]
        )
        self.link_fn = link_fn

    def forward(self, X):
        """Torch Forward pass.

        Args:
            X: Input indexed as ``X[:, i]`` per feature; presumably shaped
                ``(batch, input_dim)`` — TODO confirm against callers.

        Returns:
            ``link_fn`` applied to the sum of the per-feature sub-network outputs.
        """
        ret = self.nams[0](X[:, 0].unsqueeze(-1))
        # Pair sub-network i with feature column i (the previous version
        # enumerated nams[1:] from 0, feeding feature i-1 to sub-network i and
        # ignoring the last feature entirely).
        for i in range(1, len(self.nams)):
            # Out-of-place add: don't mutate a sub-network's output tensor.
            ret = ret + self.nams[i](X[:, i].unsqueeze(-1))
        return self.link_fn(ret)

140 

141 

class Linear(nn.Module):
    """
    Linear encoder

    Thin wrapper around a zero-hidden-layer ``MLP``: a single ``nn.Linear``
    from ``input_dim`` to ``output_dim`` followed by MLP's default ``link_fn``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # With layers=0 the MLP ignores `width` and `activation`.
        self.linear = MLP(
            input_dim=input_dim,
            output_dim=output_dim,
            width=output_dim,
            layers=0,
            activation=None,
        )

    def forward(self, X):
        """Torch Forward pass: apply the wrapped linear map."""
        encoded = self.linear(X)
        return encoded

156 

157 

# Registry mapping encoder-type names to the module classes defined above.
ENCODERS = dict(mlp=MLP, ngam=NGAM, linear=Linear)