1# =====================================================
2# Imports
3# =====================================================
4from enum import Enum
5from typing import Annotated
6
7import numpy as np
8from loguru import logger
9from numpy._typing import NDArray
10from pydantic import Field
11
12from mt_metadata.features.weights.monotonic_weight_kernel import MonotonicWeightKernel
13
14
15# =====================================================
class ThresholdEnum(str, Enum):
    """Side of a threshold that a weight kernel should downweight.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values (e.g. ``ThresholdEnum.low_cut == "low cut"``).
    """

    # Downweight the low side of the threshold.
    low_cut = "low cut"
    # Downweight the high side of the threshold.
    high_cut = "high cut"
19
20
class ActivationStyleEnum(str, Enum):
    """Names of the activation (tapering) functions a kernel can apply.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values (e.g. ``ActivationStyleEnum.tanh == "tanh"``).
    """

    sigmoid = "sigmoid"
    hard_sigmoid = "hard_sigmoid"
    tanh = "tanh"
    hard_tanh = "hard_tanh"
26
27
class ActivationMonotonicWeightKernel(MonotonicWeightKernel):
    """Monotonic weight kernel that tapers with an activation function.

    Input values are first normalized to ``[0, 1]`` across the transition
    bounds (``transition_lower_bound`` / ``transition_upper_bound``, which
    are not declared here -- presumably inherited from
    ``MonotonicWeightKernel``), then passed through the activation selected
    by ``activation_style``.
    """

    threshold: Annotated[
        ThresholdEnum,
        Field(
            default=ThresholdEnum.low_cut,
            description="Which side of a threshold should be downweighted.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["low cut"],
            },
        ),
    ]

    activation_style: Annotated[
        ActivationStyleEnum,
        Field(
            default=ActivationStyleEnum.sigmoid,
            description="Tapering/activation function to use between transition bounds.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["tanh"],
            },
        ),
    ]

    # NOTE: the original schema definition listed the default as None; 1.0 is
    # used here so the kernel works without explicit configuration.
    steepness: Annotated[
        float,
        Field(
            default=1.0,
            description="Controls the sharpness of the activation transition.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["10"],
            },
        ),
    ]

    def _effective_steepness(self) -> float:
        """Return ``steepness`` as a float, using 1.0 when it is ``None``.

        ``None`` can only appear if validation was bypassed or legacy
        metadata was loaded; falling back to the declared field default
        avoids a ``TypeError`` from ``float(None)``.
        """
        if self.steepness is None:
            return 1.0
        return float(self.steepness)

    def _normalize(self, values: NDArray) -> NDArray:
        """
        Normalize input values to the [0, 1] interval for activation kernels.

        For finite bounds the values are mapped linearly so that the lower
        bound goes to 0 and the upper bound to 1, the result is reversed for
        a 'high cut' threshold, and finally clipped to [0, 1]. If both bounds
        coincide, the mapping degenerates to a step: 0 below the bound,
        1 above it, and 0.5 exactly on it (NaN inputs stay NaN).

        For non-finite bounds a warning is logged and every value is mapped
        to 0.5; subclasses may override this to define other behavior.

        Parameters
        ----------
        values : NDArray
            Input values (anything accepted by ``np.asarray``).

        Returns
        -------
        NDArray
            Array with the same shape as ``values``.
        """
        lb = float(self.transition_lower_bound)
        ub = float(self.transition_upper_bound)
        values = np.asarray(values)
        direction = getattr(self, "threshold", ThresholdEnum.low_cut)

        if np.isfinite(lb) and np.isfinite(ub):
            width = ub - lb
            if width == 0:
                # Zero-width transition: avoid dividing by zero and use a
                # step function instead (sign -> -1/0/+1 -> 0/0.5/1).
                x = 0.5 * (np.sign(values - lb) + 1.0)
            else:
                x = (values - lb) / width
            if direction == ThresholdEnum.high_cut:
                x = 1 - x
            return np.clip(x, 0, 1)

        # Infinite bounds: fallback (could be extended for custom behavior)
        msg = "ActivationMonotonicWeightKernel only supports finite transition bounds. "
        logger.warning(msg + "Returning 0.5 for all values.")
        # Build an explicit float array: np.full_like would inherit an
        # integer dtype from ``values`` and truncate 0.5 to 0.
        return np.full(values.shape, 0.5, dtype=float)

    def evaluate(self, values: NDArray) -> NDArray:
        """
        Evaluate the activation function for the given input values.

        The values are normalized with ``_normalize`` and the activation is
        applied to the normalized values centered on 0.5. ``steepness`` is
        only used by the 'sigmoid' and 'tanh' styles.

        NOTE(review): because the normalized input is confined to [0, 1],
        'hard_sigmoid' only spans [0.4, 0.6], and 'sigmoid'/'tanh' approach
        0 and 1 only for large ``steepness`` -- confirm this is intended.

        Parameters
        ----------
        values : NDArray
            Input values to be evaluated.

        Returns
        -------
        NDArray
            Evaluated activation values.

        Raises
        ------
        ValueError
            If the activation style is not recognized.
        """
        x = self._normalize(values)
        activation_style = self.activation_style

        if activation_style == ActivationStyleEnum.sigmoid:
            y = 1 / (1 + np.exp(-self._effective_steepness() * (x - 0.5)))
        elif activation_style == ActivationStyleEnum.hard_sigmoid:
            y = np.clip(0.2 * (x - 0.5) + 0.5, 0, 1)
        elif activation_style == ActivationStyleEnum.tanh:
            y = 0.5 * (np.tanh(self._effective_steepness() * (x - 0.5)) + 1)
        elif activation_style == ActivationStyleEnum.hard_tanh:
            y = np.clip(x, 0, 1)
        else:
            raise ValueError(f"Unsupported activation style: {activation_style}")

        return y