Coverage for src/autoencodix/modeling/_classifier.py: 38%

16 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2import torch.nn as nn 

3 

4 

class Classifier(nn.Module):
    """Multi-class classifier for adversarial training in n-modal latent space alignment.

    The network is a three-layer MLP:
    ``Linear(input_dim, n_hidden) -> ReLU -> Dropout(0.1) ->
    Linear(n_hidden, n_hidden // 2) -> ReLU -> Dropout(0.1) ->
    Linear(n_hidden // 2, n_modalities)``.
    No softmax is applied; ``forward`` returns raw class scores.

    Attributes:
        input_dim: Dimension of the input features.
        n_modalities: Number of modalities (classes) to classify.
        n_hidden: Number of hidden units in the classifier network.
        classifier: The ``nn.Sequential`` stack that maps inputs to class scores.
    """

    def __init__(self, input_dim: int, n_modalities: int, n_hidden: int = 64) -> None:
        """Build the classifier network and initialize its weights.

        Args:
            input_dim: Dimension of the input features.
            n_modalities: Number of modalities (output classes).
            n_hidden: Width of the first hidden layer. The second hidden layer
                uses ``n_hidden // 2`` units, so values below 2 give it zero width.
        """
        super().__init__()
        self.input_dim = input_dim
        self.n_modalities = n_modalities
        # Stored so the documented ``n_hidden`` attribute actually exists.
        self.n_hidden = n_hidden

        bottleneck = n_hidden // 2
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, n_hidden),
            nn.ReLU(inplace=False),
            nn.Dropout(0.1),
            nn.Linear(n_hidden, bottleneck),
            nn.ReLU(inplace=False),
            nn.Dropout(0.1),
            nn.Linear(bottleneck, n_modalities),
        )
        # ``apply`` visits every submodule, so each Linear layer is initialized.
        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the classifier.

        Args:
            x: Input tensor of shape (batch_size, input_dim).

        Returns:
            Output tensor of shape (batch_size, n_modalities) representing class scores.
        """
        return self.classifier(x)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize a single submodule; only ``nn.Linear`` layers are touched.

        Weights get Xavier-uniform initialization; biases (when present) are
        set to the constant 0.01.

        Args:
            m: Submodule handed over by ``nn.Module.apply``.
        """
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            # In-place constant fill under no_grad, without going through ``.data``.
            nn.init.constant_(m.bias, 0.01)