Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\features\weights\feature_weight_spec.py: 77%

93 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4from enum import Enum 

5from typing import Annotated 

6 

7import numpy as np 

8from loguru import logger 

9from pydantic import Field, field_validator, model_validator, ValidationInfo 

10 

11from mt_metadata.base import MetadataBase 

12from mt_metadata.features.coherence import Coherence 

13from mt_metadata.features.fc_coherence import FCCoherence 

14from mt_metadata.features.feature import Feature 

15from mt_metadata.features.striding_window_coherence import StridingWindowCoherence 

16from mt_metadata.features.weights.activation_monotonic_weight_kernel import ( 

17 ActivationMonotonicWeightKernel, 

18) 

19from mt_metadata.features.weights.monotonic_weight_kernel import MonotonicWeightKernel 

20from mt_metadata.features.weights.taper_monotonic_weight_kernel import ( 

21 TaperMonotonicWeightKernel, 

22) 

23 

24 

# Registry mapping a feature dict's "name" string to the class used to build it.
# To support a new feature type, import its class above and add an entry here.
feature_classes = dict(
    base=Feature,
    coherence=Coherence,
    fc_coherence=FCCoherence,
    striding_window_coherence=StridingWindowCoherence,
)

# Registry mapping a weight-kernel dict's "style" (or legacy "weight_type")
# string to the kernel class used to build it.
weight_classes = dict(
    monotonic=MonotonicWeightKernel,
    taper=TaperMonotonicWeightKernel,
    activation=ActivationMonotonicWeightKernel,
)

38 

39 

40# ===================================================== 

class FeatureNameEnum(str, Enum):
    """Allowed values for ``FeatureWeightSpec.feature_name``.

    Members subclass ``str``, so each compares equal to its stored value.

    NOTE(review): these values do not cover every key of ``feature_classes``
    (e.g. "fc_coherence", "striding_window_coherence") -- confirm whether
    ``feature_name`` is meant to accept those as well.
    """

    # Plain coherence between two channels.
    coherence = "coherence"
    # The stored value uses a space, not an underscore.
    multiple_coherence = "multiple coherence"

44 

45 

# Concrete feature classes that a ``feature`` value may already be an instance of.
_FEATURE_TYPES = (Feature, Coherence, FCCoherence, StridingWindowCoherence)

# Concrete kernel classes that a ``weight_kernels`` entry may already be an instance of.
_KERNEL_TYPES = (
    MonotonicWeightKernel,
    TaperMonotonicWeightKernel,
    ActivationMonotonicWeightKernel,
)


def _unwrap_feature(value):
    """Strip ``{"feature": ...}`` wrapper layers around a feature dict or instance.

    A layer is removed only when its ``"feature"`` entry is itself a dict or a
    feature instance. A dict holding some other value under ``"feature"`` is
    returned unchanged.
    """
    while isinstance(value, dict) and isinstance(
        value.get("feature"), (dict, *_FEATURE_TYPES)
    ):
        value = value["feature"]
    return value


def _build_feature(feature_data: dict):
    """Instantiate the ``feature_classes`` entry for ``feature_data["name"]``.

    Unknown or non-string names fall back to the base ``Feature`` class.
    """
    feature_name = feature_data.get("name")
    # The isinstance guard also prevents an unhashable name from raising
    # TypeError during the dict lookup; all registry keys are strings anyway.
    feature_cls = (
        feature_classes.get(feature_name) if isinstance(feature_name, str) else None
    )
    if feature_cls is None:
        logger.warning(
            f"Feature name '{feature_name}' not in feature_classes, using base Feature"
        )
        feature_cls = Feature
    logger.debug(f"Selected feature class: {feature_cls.__name__}")
    return feature_cls(**feature_data)


def _kernel_registry_key(raw) -> str:
    """Normalize a kernel ``style``/``weight_type`` value to a ``weight_classes`` key.

    Enum members are reduced to their ``.value`` first. ``str()`` of a
    ``(str, Enum)`` member yields ``"ClassName.member"``, which would never
    match a registry key.
    """
    if isinstance(raw, Enum):
        raw = raw.value
    return str(raw)


def _build_weight_kernel(item: dict):
    """Build one weight kernel from its dict spec, or return ``None`` to skip it.

    The kernel class is chosen by ``style``, falling back to the legacy
    ``weight_type`` key. Unrecognized keys and construction failures are
    logged and the kernel is skipped. This is deliberately best-effort, so one
    bad kernel does not invalidate the whole spec.
    """
    style = _kernel_registry_key(item.get("style", ""))
    if style in weight_classes:
        key, label = style, f"style '{style}'"
    else:
        # Fallback to weight_type for backward compatibility.
        weight_type = _kernel_registry_key(item.get("weight_type", ""))
        if weight_type not in weight_classes:
            logger.warning(
                f"Neither style '{style}' nor weight_type '{weight_type}' recognized -- skipping"
            )
            return None
        key, label = weight_type, f"weight_type '{weight_type}'"
    try:
        return weight_classes[key](**item)
    except Exception as e:
        logger.warning(f"Failed to create weight kernel with {label}: {e}")
        return None


class FeatureWeightSpec(MetadataBase):
    """Pair a feature specification with the weight kernels applied to its values.

    ``feature`` is dispatched to a concrete class through the module-level
    ``feature_classes`` registry, keyed by the feature dict's ``"name"``.
    ``weight_kernels`` holds the kernels whose outputs :meth:`evaluate`
    multiplies together.
    """

    # NOTE(review): the default "" is not a FeatureNameEnum member. Unless
    # MetadataBase's model config enables default validation (not visible
    # here), an unset feature_name stays the plain string "". The description
    # also mentions "impedance_ratio", which is not an enum member -- confirm.
    feature_name: Annotated[
        FeatureNameEnum,
        Field(
            default="",
            description="The name of the feature to evaluate (e.g., coherence, impedance_ratio).",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["coherence"],
            },
        ),
    ]

    feature: Annotated[
        dict | Feature | Coherence | FCCoherence | StridingWindowCoherence,
        Field(
            default_factory=Feature,  # type: ignore
            description="The feature specification.",
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [{"type": "coherence"}],
            },
        ),
    ]

    weight_kernels: Annotated[
        list[
            MonotonicWeightKernel
            | TaperMonotonicWeightKernel
            | ActivationMonotonicWeightKernel
        ],
        Field(
            default_factory=list,
            description="List of weight kernel specification.",
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [{"type": "monotonic"}],
            },
        ),
    ]

    @model_validator(mode="before")
    @classmethod
    def pre_process_feature(cls, data: dict) -> dict:
        """Turn a (possibly wrapped) feature dict into the registered feature class.

        Returns a shallow copy of ``data`` with ``"feature"`` replaced, so the
        caller's input dict is never mutated. Non-dict input, or input without
        a ``"feature"`` key, is returned unchanged.
        """
        if not isinstance(data, dict) or "feature" not in data:
            return data
        feature_data = _unwrap_feature(data["feature"])
        if isinstance(feature_data, dict):
            logger.debug(
                f"pre_process_feature: feature_name={feature_data.get('name')}"
            )
            feature_data = _build_feature(feature_data)
        return {**data, "feature": feature_data}

    @field_validator("feature", mode="before")
    @classmethod
    def validate_feature(
        cls, value, info: ValidationInfo
    ) -> Feature | Coherence | FCCoherence | StridingWindowCoherence | None:
        """Coerce the ``feature`` field into a feature instance.

        Dicts, including ones wrapped in ``{"feature": ...}`` layers, are built
        via the ``feature_classes`` registry. Feature instances pass through
        unchanged. Any other value is logged and ``None`` is returned. ``None``
        is not an accepted type for this field, so validation then fails.
        """
        logger.debug(
            f"validate_feature called with value type: {type(value)}, value: {value}"
        )
        value = _unwrap_feature(value)
        if isinstance(value, dict):
            return _build_feature(value)
        if isinstance(value, _FEATURE_TYPES):
            logger.debug(f"Feature setter: set directly to {type(value).__name__}")
            return value
        logger.warning(
            f"Feature value is neither dict nor Feature instance: {type(value)}"
        )
        return None

    @field_validator("weight_kernels", mode="before")
    @classmethod
    def validate_weight_kernels(
        cls, value, info: ValidationInfo
    ) -> list[
        MonotonicWeightKernel
        | TaperMonotonicWeightKernel
        | ActivationMonotonicWeightKernel
    ]:
        """Coerce ``weight_kernels`` into a list of kernel instances.

        ``None`` means no kernels. A tuple is treated like a list. Any other
        single value is wrapped in a list. Each item may be a kernel instance,
        a kernel dict, or a ``{"weight_kernel": {...}}`` wrapper. Dicts that
        cannot be built are logged and skipped.

        Raises
        ------
        TypeError
            If an item is neither a dict nor a known kernel instance.
        """
        if value is None:
            return []
        if isinstance(value, tuple):
            value = list(value)
        elif not isinstance(value, list):
            value = [value]

        kernels = []
        for item in value:
            if isinstance(item, dict) and "weight_kernel" in item:
                item = item["weight_kernel"]
            if isinstance(item, _KERNEL_TYPES):
                kernels.append(item)
            elif isinstance(item, dict):
                kernel = _build_weight_kernel(item)
                if kernel is not None:
                    kernels.append(kernel)
            else:
                raise TypeError(f"Invalid type for weight_kernel: {type(item)}")
        return kernels

    def evaluate(self, feature_values):
        """
        Evaluate this feature's weighting based on the list of kernels.

        Parameters
        ----------
        feature_values : np.ndarray or float
            The computed values for this feature.

        Returns
        -------
        combined_weight : np.ndarray or float
            The product of every kernel's output, taken along the kernel axis
            (``np.prod(..., axis=0)``), or ``1.0`` when there are no kernels.
        """
        weights = [kernel.evaluate(feature_values) for kernel in self.weight_kernels]
        return np.prod(weights, axis=0) if weights else 1.0