Coverage for src/autoencodix/modeling/_classifier.py: 38%
16 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2import torch.nn as nn
class Classifier(nn.Module):
    """Multi-class classifier for adversarial training in n-modal latent space alignment.

    A small MLP with two hidden layers (Linear -> ReLU -> Dropout each). It maps
    an input feature vector to one unnormalized score (logit) per modality.

    Args:
        input_dim: Dimension of the input features. Must be >= 1.
        n_modalities: Number of modalities (classes) to classify. Must be >= 1.
        n_hidden: Width of the first hidden layer. The second hidden layer uses
            ``n_hidden // 2`` units, so this must be >= 2.
        dropout: Dropout probability applied after each hidden activation
            (keyword-only; defaults to the previously hard-coded 0.1).

    Attributes:
        input_dim: Dimension of the input features.
        n_modalities: Number of modalities (classes) to classify.
        n_hidden: Number of hidden units in the first hidden layer.
        dropout: Dropout probability used in the hidden layers.
        classifier: The ``nn.Sequential`` network that produces the logits.

    Raises:
        ValueError: If ``input_dim`` or ``n_modalities`` is < 1, or ``n_hidden`` is < 2.
    """

    def __init__(
        self,
        input_dim: int,
        n_modalities: int,
        n_hidden: int = 64,
        *,
        dropout: float = 0.1,
    ) -> None:
        if input_dim < 1:
            raise ValueError(f"input_dim must be >= 1, got {input_dim}")
        if n_modalities < 1:
            raise ValueError(f"n_modalities must be >= 1, got {n_modalities}")
        if n_hidden < 2:
            # The second hidden layer has n_hidden // 2 units; below 2 it would be
            # zero-width and the output would no longer depend on the input.
            raise ValueError(f"n_hidden must be >= 2, got {n_hidden}")
        super().__init__()
        self.input_dim = input_dim
        self.n_modalities = n_modalities
        self.n_hidden = n_hidden
        self.dropout = dropout
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, n_hidden),
            nn.ReLU(inplace=False),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, n_hidden // 2),
            nn.ReLU(inplace=False),
            nn.Dropout(dropout),
            # No final activation: the network emits raw logits.
            nn.Linear(n_hidden // 2, n_modalities),
        )
        # Recursively visits every submodule and initialises the Linear layers.
        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the classifier.

        Args:
            x: Input tensor of shape (batch_size, input_dim).

        Returns:
            Tensor of shape (batch_size, n_modalities) with unnormalized class
            scores (logits); no softmax is applied.
        """
        return self.classifier(x)

    def _init_weights(self, m: nn.Module) -> None:
        """Xavier-uniform weights and a constant 0.01 bias for every ``nn.Linear``.

        Modules of any other type are left untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                # Documented init API instead of mutating ``.data`` directly.
                nn.init.constant_(m.bias, 0.01)