Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mt_metadata \ mt_metadata \ features \ weights \ activation_monotonic_weight_kernel.py: 100%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4from enum import Enum 

5from typing import Annotated 

6 

7import numpy as np 

8from loguru import logger 

9from numpy._typing import NDArray 

10from pydantic import Field 

11 

12from mt_metadata.features.weights.monotonic_weight_kernel import MonotonicWeightKernel 

13 

14 

15# ===================================================== 

class ThresholdEnum(str, Enum):
    """
    Side of the threshold on which values are downweighted.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values (e.g. ``ThresholdEnum.low_cut == "low cut"``).
    """

    # Downweight the low side of the threshold.
    low_cut = "low cut"

    # Downweight the high side of the threshold.
    high_cut = "high cut"

19 

20 

class ActivationStyleEnum(str, Enum):
    """
    Names of the supported activation (tapering) functions.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values (e.g. ``ActivationStyleEnum.tanh == "tanh"``).
    """

    # Smooth styles.
    sigmoid = "sigmoid"

    # Piecewise-linear counterpart of ``sigmoid``.
    hard_sigmoid = "hard_sigmoid"

    tanh = "tanh"

    # Piecewise-linear counterpart of ``tanh``.
    hard_tanh = "hard_tanh"

26 

27 

class ActivationMonotonicWeightKernel(MonotonicWeightKernel):
    """
    Monotonic weight kernel whose taper is an activation function.

    Input values are first mapped linearly onto ``[0, 1]`` across the
    transition interval (``transition_lower_bound`` to
    ``transition_upper_bound``, read from ``self`` -- presumably declared on
    ``MonotonicWeightKernel``; confirm there), flipped for ``"high cut"``,
    and then passed through the function selected by ``activation_style``.

    Attributes
    ----------
    threshold : ThresholdEnum
        Which side of the threshold is downweighted. Defaults to ``low_cut``.
    activation_style : ActivationStyleEnum
        Activation function applied to the normalized values. Defaults to
        ``sigmoid``.
    steepness : float
        Multiplier applied to ``(x - 0.5)`` in the ``sigmoid`` and ``tanh``
        styles; not used by the ``hard_*`` styles. Defaults to 1.0.
    """

    threshold: Annotated[
        ThresholdEnum,
        Field(
            default=ThresholdEnum.low_cut,
            description="Which side of a threshold should be downweighted.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["low cut"],
            },
        ),
    ]

    activation_style: Annotated[
        ActivationStyleEnum,
        Field(
            default=ActivationStyleEnum.sigmoid,
            description="Tapering/activation function to use between transition bounds.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["tanh"],
            },
        ),
    ]

    # NOTE: the default is 1.0 (not None) so that float(self.steepness) in
    # evaluate() always succeeds for a default-constructed kernel.
    steepness: Annotated[
        float,
        Field(
            default=1.0,
            description="Controls the sharpness of the activation transition.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["10"],
            },
        ),
    ]

    def _normalize(self, values: NDArray) -> NDArray:
        """
        Map input values onto the [0, 1] interval used by the activations.

        For finite bounds the mapping is ``(values - lb) / (ub - lb)``,
        replaced by ``1 - x`` when ``threshold`` is "high cut", and finally
        clipped to ``[0, 1]``.

        Parameters
        ----------
        values : NDArray
            Input values (anything ``np.asarray`` accepts).

        Returns
        -------
        NDArray
            Normalized values with the same shape as ``values``.

            * If either bound is not finite, a warning is logged and an
              array filled with 0.5 is returned. The result is always a
              floating-point array, so integer input no longer truncates
              the fill value to 0.
            * If both bounds are equal, the transition has zero width and is
              treated as a hard step at that bound (1 at or above it, 0
              below, before the "high cut" flip) instead of dividing by
              zero. NaN inputs stay NaN.
        """
        lb = float(self.transition_lower_bound)
        ub = float(self.transition_upper_bound)
        values = np.asarray(values)

        # Infinite (or NaN) bounds: no linear mapping exists, fall back to 0.5.
        if not (np.isfinite(lb) and np.isfinite(ub)):
            msg = "ActivationMonotonicWeightKernel only supports finite transition bounds. "
            logger.warning(msg + "Returning 0.5 for all values.")
            # np.full_like would inherit an integer dtype and store 0, not 0.5.
            if np.issubdtype(values.dtype, np.floating):
                fill_dtype = values.dtype
            else:
                fill_dtype = np.float64
            return np.full(values.shape, 0.5, dtype=fill_dtype)

        if ub == lb:
            # Zero-width transition: step function rather than 0/0 -> NaN.
            x = np.heaviside(values - lb, 1.0)
        else:
            x = (values - lb) / (ub - lb)

        # The enum mixes in str, so this also matches a plain "high cut".
        if self.threshold == ThresholdEnum.high_cut:
            x = 1 - x
        return np.clip(x, 0, 1)

    def evaluate(self, values: NDArray) -> NDArray:
        """
        Evaluate the activation function for the given input values.

        With ``x = self._normalize(values)`` the styles compute:

        * ``sigmoid``: ``1 / (1 + exp(-steepness * (x - 0.5)))``
        * ``hard_sigmoid``: ``clip(0.2 * (x - 0.5) + 0.5, 0, 1)``
        * ``tanh``: ``0.5 * (tanh(steepness * (x - 0.5)) + 1)``
        * ``hard_tanh``: ``clip(x, 0, 1)``

        Parameters
        ----------
        values : NDArray
            Input values to be evaluated.

        Returns
        -------
        NDArray
            Evaluated activation values.

        Raises
        ------
        ValueError
            If the activation style is not recognized.
        TypeError
            If ``steepness`` is None when the ``sigmoid`` or ``tanh`` style
            is evaluated (``float(None)`` fails).
        """
        x = self._normalize(values)
        activation_style = self.activation_style

        # steepness is converted inside the branches that use it, so the
        # hard_* styles never touch it.
        if activation_style == ActivationStyleEnum.sigmoid:
            y = 1 / (1 + np.exp(-float(self.steepness) * (x - 0.5)))
        elif activation_style == ActivationStyleEnum.hard_sigmoid:
            y = np.clip(0.2 * (x - 0.5) + 0.5, 0, 1)
        elif activation_style == ActivationStyleEnum.tanh:
            y = 0.5 * (np.tanh(float(self.steepness) * (x - 0.5)) + 1)
        elif activation_style == ActivationStyleEnum.hard_tanh:
            y = np.clip(x, 0, 1)
        else:
            raise ValueError(f"Unsupported activation style: {activation_style}")

        return y