Coverage for src / autoencodix / modeling / _layer_factory.py: 88%

24 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch.nn as nn 

2from typing import List 

3 

4 

class LayerFactory:
    """Factory for creating configurable neural network layers."""

    @staticmethod
    def get_layer_dimensions(
        feature_dim: int, latent_dim: int, n_layers: int, enc_factor: float
    ) -> List[int]:
        """Calculate progressive layer dimensions.

        Starting from ``feature_dim``, each of the ``n_layers`` hidden sizes is
        the previous size divided by ``enc_factor`` (truncated to an int), but
        never smaller than ``latent_dim``. ``latent_dim`` is always appended as
        the final entry. With ``n_layers <= 0`` this is a direct projection
        ``[feature_dim, latent_dim]``.

        Args:
            feature_dim: Input feature dimension
            latent_dim: Target latent dimension
            n_layers: Number of hidden layers between input and latent space
            enc_factor: Reduction factor for layer sizes

        Returns:
            Calculated layer dimensions, of length ``max(n_layers, 0) + 2``
            (input size, hidden sizes, latent size).
        """
        layer_dimensions = [feature_dim]
        # For n_layers == 0 the loop is skipped, so enc_factor is never divided
        # and the result is the direct input->latent projection.
        for _ in range(n_layers):
            prev_layer_size = layer_dimensions[-1]
            # Clamp at latent_dim so hidden layers never shrink below the latent size.
            next_layer_size = max(int(prev_layer_size / enc_factor), latent_dim)
            layer_dimensions.append(next_layer_size)
        layer_dimensions.append(latent_dim)
        return layer_dimensions

    @staticmethod
    def create_layer(
        in_features: int,
        out_features: int,
        dropout_p: float = 0.1,
        last_layer: bool = False,
    ) -> List[nn.Module]:
        """Create a configurable layer with optional components.

        Args:
            in_features: Input feature dimension
            out_features: Output feature dimension
            dropout_p: Dropout probability, by default 0.1
            last_layer: Flag to skip normalization/dropout/activation for the
                final layer, by default False

        Returns:
            List of layer components: ``[Linear]`` for the last layer, otherwise
            ``[Linear, BatchNorm1d, Dropout, ReLU]``.
        """
        if last_layer:
            return [nn.Linear(in_features, out_features)]

        return [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            # Dropout before ReLU is equivalent to after it: inverted dropout
            # zeroes entries and scales the rest by a positive constant, both of
            # which commute with ReLU.
            nn.Dropout(dropout_p),
            nn.ReLU(),
        ]

    @staticmethod
    def create_maskix_layer(
        in_features: int, out_features: int, last_layer: bool = False
    ) -> List[nn.Module]:
        """Create a layer block for the Maskix architecture.

        Args:
            in_features: Input feature dimension
            out_features: Output feature dimension
            last_layer: Flag to skip normalization/activation for the final
                layer, by default False (matching ``create_layer``)

        Returns:
            List of layer components: ``[Linear]`` for the last layer, otherwise
            ``[Linear, LayerNorm, Mish]``.
        """
        if last_layer:
            return [nn.Linear(in_features, out_features)]
        return [
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            # NOTE(review): inplace Mish overwrites its input tensor; this assumes
            # the components are chained (e.g. nn.Sequential) so that input is the
            # LayerNorm output and not reused elsewhere — confirm with callers.
            nn.Mish(inplace=True),
        ]