Coverage for src/autoencodix/modeling/_layer_factory.py: 88%
24 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch.nn as nn
2from typing import List
class LayerFactory:
    """Factory for creating configurable neural network layers.

    All methods are static; the class only groups related layer builders.
    """

    @staticmethod
    def get_layer_dimensions(
        feature_dim: int, latent_dim: int, n_layers: int, enc_factor: float
    ) -> List[int]:
        """Calculate progressive layer dimensions from input to latent space.

        Starting at ``feature_dim``, each of the ``n_layers`` intermediate
        sizes is the previous size divided by ``enc_factor`` (truncated to an
        int), but never smaller than ``latent_dim``. ``latent_dim`` is always
        appended as the final entry.

        Args:
            feature_dim: Input feature dimension (first entry of the result).
            latent_dim: Target latent dimension (last entry of the result and
                lower bound for every intermediate size).
            n_layers: Number of intermediate (hidden) sizes inserted between
                ``feature_dim`` and ``latent_dim``. ``0`` (or a negative value)
                yields a direct projection ``[feature_dim, latent_dim]``.
            enc_factor: Divisor applied to each successive layer size. Only
                used when ``n_layers > 0``.

        Returns:
            Layer dimensions ``[feature_dim, h_1, ..., h_n, latent_dim]``,
            i.e. ``n_layers + 2`` entries when ``n_layers >= 0``.

        Raises:
            ValueError: If ``n_layers > 0`` and ``enc_factor`` is not a
                positive number. Zero would divide by zero, and a negative
                factor would silently collapse every hidden size to
                ``latent_dim``.
        """
        if n_layers <= 0:
            # Direct projection from input to latent. This matches what the
            # loop below would produce for a non-positive n_layers, and keeps
            # enc_factor unused (and unvalidated) on this path.
            return [feature_dim, latent_dim]

        # `not > 0` (rather than `<= 0`) also rejects NaN.
        if not enc_factor > 0:
            raise ValueError(
                f"enc_factor must be a positive number, got {enc_factor!r}"
            )

        layer_dimensions = [feature_dim]
        for _ in range(n_layers):
            next_layer_size = max(int(layer_dimensions[-1] / enc_factor), latent_dim)
            layer_dimensions.append(next_layer_size)
        layer_dimensions.append(latent_dim)

        return layer_dimensions

    @staticmethod
    def create_layer(
        in_features: int,
        out_features: int,
        dropout_p: float = 0.1,
        last_layer: bool = False,
    ) -> List[nn.Module]:
        """Create a configurable layer with optional components.

        Args:
            in_features: Input feature dimension.
            out_features: Output feature dimension.
            dropout_p: Dropout probability, by default 0.1.
            last_layer: If True, return only the linear projection (no
                normalization, dropout or activation), by default False.

        Returns:
            List of layer components: ``[Linear]`` for the last layer,
            otherwise ``[Linear, BatchNorm1d, Dropout, ReLU]``.
        """
        if last_layer:
            return [nn.Linear(in_features, out_features)]

        # Dropout before ReLU is equivalent to the reverse order: dropout only
        # zeroes or positively rescales values, and ReLU commutes with both.
        return [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.Dropout(dropout_p),
            nn.ReLU(),
        ]

    @staticmethod
    def create_maskix_layer(
        in_features: int, out_features: int, last_layer: bool = False
    ) -> List[nn.Module]:
        """Create a layer variant using LayerNorm and Mish activation.

        Args:
            in_features: Input feature dimension.
            out_features: Output feature dimension.
            last_layer: If True, return only the linear projection, by default
                False (consistent with ``create_layer``).

        Returns:
            List of layer components: ``[Linear]`` for the last layer,
            otherwise ``[Linear, LayerNorm, Mish]``.
        """
        if last_layer:
            return [nn.Linear(in_features, out_features)]

        return [
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            # inplace=True overwrites the LayerNorm output to save memory.
            nn.Mish(inplace=True),
        ]