Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\features\weights\feature_weight_spec.py: 77%
93 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1# =====================================================
2# Imports
3# =====================================================
4from enum import Enum
5from typing import Annotated
7import numpy as np
8from loguru import logger
9from pydantic import Field, field_validator, model_validator, ValidationInfo
11from mt_metadata.base import MetadataBase
12from mt_metadata.features.coherence import Coherence
13from mt_metadata.features.fc_coherence import FCCoherence
14from mt_metadata.features.feature import Feature
15from mt_metadata.features.striding_window_coherence import StridingWindowCoherence
16from mt_metadata.features.weights.activation_monotonic_weight_kernel import (
17 ActivationMonotonicWeightKernel,
18)
19from mt_metadata.features.weights.monotonic_weight_kernel import MonotonicWeightKernel
20from mt_metadata.features.weights.taper_monotonic_weight_kernel import (
21 TaperMonotonicWeightKernel,
22)
# Registry of feature classes, looked up by the feature dict's "name" value.
# To support a new feature, import its class above and register it here.
feature_classes = dict(
    base=Feature,
    coherence=Coherence,
    fc_coherence=FCCoherence,
    striding_window_coherence=StridingWindowCoherence,
)

# Registry of weight-kernel classes, looked up by a kernel dict's "style"
# (or, for backward compatibility, "weight_type") value.
weight_classes = dict(
    monotonic=MonotonicWeightKernel,
    taper=TaperMonotonicWeightKernel,
    activation=ActivationMonotonicWeightKernel,
)
40# =====================================================
class FeatureNameEnum(str, Enum):
    """Enumerated values accepted by ``FeatureWeightSpec.feature_name``."""

    # Member order is significant for iteration; keep it stable.
    coherence = "coherence"
    # The stored value uses a space, unlike the attribute name.
    multiple_coherence = "multiple coherence"
class FeatureWeightSpec(MetadataBase):
    """
    Specification tying one feature to the weight kernels applied to it.

    Attributes
    ----------
    feature_name : FeatureNameEnum
        Enumerated name of the feature to evaluate.
    feature : Feature | Coherence | FCCoherence | StridingWindowCoherence
        The feature specification. Raw dicts are converted to the class
        registered in ``feature_classes`` under the dict's ``"name"`` value.
    weight_kernels : list of weight-kernel instances
        Kernels whose outputs are multiplied together by :meth:`evaluate`.
        Raw dicts are converted using ``weight_classes``, keyed by the
        dict's ``"style"`` value (falling back to ``"weight_type"``).
    """

    # NOTE(review): the default "" is not a FeatureNameEnum member, and the
    # "impedance_ratio" example in the description is not an enum value
    # either. Both are left unchanged so schema/serialized output stays the
    # same -- confirm whether the default should be a real member (this only
    # goes unnoticed if the model config does not set validate_default).
    feature_name: Annotated[
        FeatureNameEnum,
        Field(
            default="",
            description="The name of the feature to evaluate (e.g., coherence, impedance_ratio).",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["coherence"],
            },
        ),
    ]

    # Dict input is converted to a concrete feature class by the validators
    # below (pre_process_feature / validate_feature).
    feature: Annotated[
        dict | Feature | Coherence | FCCoherence | StridingWindowCoherence,
        Field(
            default_factory=Feature,  # type: ignore
            description="The feature specification.",
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [{"type": "coherence"}],
            },
        ),
    ]

    # Dict input is converted to kernel instances by validate_weight_kernels.
    weight_kernels: Annotated[
        list[
            MonotonicWeightKernel
            | TaperMonotonicWeightKernel
            | ActivationMonotonicWeightKernel
        ],
        Field(
            default_factory=list,
            description="List of weight kernel specification.",
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [{"type": "monotonic"}],
            },
        ),
    ]

    @model_validator(mode="before")
    @classmethod
    def pre_process_feature(cls, data: dict) -> dict:
        """
        Instantiate the concrete feature class before field validation.

        If ``data`` is a dict containing a ``"feature"`` entry, any
        ``{"feature": {...}}`` wrapping is peeled off. When the inner dict's
        ``"name"`` is a key of ``feature_classes``, the ``"feature"`` entry is
        replaced by an instance of the matching class.

        Parameters
        ----------
        data : dict or Any
            Raw input handed to the model. Non-dict input is returned as-is.

        Returns
        -------
        dict or Any
            A shallow copy of ``data`` with ``"feature"`` replaced when a
            matching class was found; otherwise ``data`` itself.
        """
        if not (isinstance(data, dict) and "feature" in data):
            return data

        feature_data = data["feature"]
        # Handle nested feature dict wrapping, e.g. {"feature": {"feature": {...}}}.
        while isinstance(feature_data, dict) and "feature" in feature_data:
            feature_data = feature_data["feature"]

        if not isinstance(feature_data, dict):
            return data

        feature_name = feature_data.get("name")
        logger.debug(f"pre_process_feature: feature_name={feature_name}")
        # Guard the lookup: a non-string (possibly unhashable) name would make
        # ``in feature_classes`` raise TypeError.
        if isinstance(feature_name, str) and feature_name in feature_classes:
            feature_cls = feature_classes[feature_name]
            logger.debug(
                f"pre_process_feature: Creating {feature_cls.__name__} instance"
            )
            # Build a new dict so the caller's input is not mutated.
            data = {**data, "feature": feature_cls(**feature_data)}
        else:
            logger.warning(
                f"pre_process_feature: Unknown feature name '{feature_name}', using Feature"
            )
        return data

    @field_validator("feature", mode="before")
    @classmethod
    def validate_feature(
        cls, value, info: ValidationInfo
    ) -> Feature | Coherence | FCCoherence | StridingWindowCoherence | None:
        """
        Coerce the raw ``feature`` value into a feature instance.

        Parameters
        ----------
        value : dict, feature instance, or other
            A dict (optionally wrapped as ``{"feature": {...}}``, possibly
            several levels deep) or an existing feature instance.
        info : ValidationInfo
            Supplied by pydantic; unused.

        Returns
        -------
        Feature | Coherence | FCCoherence | StridingWindowCoherence | None
            For a dict: an instance of ``feature_classes[value["name"]]``, or
            of ``Feature`` when the name is missing or unknown.
            For a feature instance: the instance unchanged.
            Otherwise ``None`` after logging a warning. NOTE(review): ``None``
            is not in the field's declared type, so pydantic presumably
            rejects it afterwards -- confirm this is intended.
        """
        logger.debug(
            f"validate_feature called with value type: {type(value)}, value: {value}"
        )
        while (
            isinstance(value, dict)
            and "feature" in value
            and isinstance(value["feature"], dict)
        ):
            logger.debug("Unwrapping nested feature dict")
            value = value["feature"]

        if isinstance(value, dict):
            feature_name = value.get("name")
            logger.debug(
                f"Feature setter: feature_name={feature_name}, value keys={value.keys()}"
            )
            if isinstance(feature_name, str) and feature_name in feature_classes:
                feature_cls = feature_classes[feature_name]
            else:
                logger.warning(
                    f"Feature name '{feature_name}' not in feature_classes, using base Feature"
                )
                feature_cls = Feature
            logger.debug(f"Selected feature class: {feature_cls.__name__}")
            feature = feature_cls(**value)
            # Log the class name (previously logged the metaclass by mistake).
            logger.debug(f"Feature setter: instantiated {feature_cls.__name__}")
            return feature

        if isinstance(
            value, (Feature, Coherence, FCCoherence, StridingWindowCoherence)
        ):
            logger.debug(f"Feature setter: set directly to {type(value).__name__}")
            return value

        logger.warning(
            f"Feature value is neither dict nor Feature instance: {type(value)}"
        )
        return None

    @field_validator("weight_kernels", mode="before")
    @classmethod
    def validate_weight_kernels(
        cls, value, info: ValidationInfo
    ) -> list[
        MonotonicWeightKernel
        | TaperMonotonicWeightKernel
        | ActivationMonotonicWeightKernel
    ]:
        """
        Convert raw ``weight_kernels`` input into a list of kernel instances.

        Parameters
        ----------
        value : list, tuple, dict, kernel instance, or None
            ``None`` is treated as an empty list, and a tuple is treated like
            a list. Any other single value is wrapped in a one-element list.
            Each item may be a kernel instance, a kernel dict, or a dict of
            the form ``{"weight_kernel": {...}}``.
        info : ValidationInfo
            Supplied by pydantic; unused.

        Returns
        -------
        list
            Kernel instances. Dict items whose class cannot be resolved, or
            whose construction fails, are logged as warnings and skipped.

        Raises
        ------
        TypeError
            If an item is neither a dict nor a supported kernel instance.
        """
        if value is None:
            items = []
        elif isinstance(value, (list, tuple)):
            items = list(value)
        else:
            items = [value]

        kernels = []
        for item in items:
            if isinstance(item, dict) and "weight_kernel" in item:
                item = item["weight_kernel"]
            if isinstance(item, dict):
                kernel = cls._build_weight_kernel(item)
                if kernel is not None:
                    kernels.append(kernel)
            elif isinstance(
                item,
                (
                    MonotonicWeightKernel,
                    TaperMonotonicWeightKernel,
                    ActivationMonotonicWeightKernel,
                ),
            ):
                kernels.append(item)
            else:
                raise TypeError(f"Invalid type for weight_kernel: {type(item)}")
        return kernels

    @staticmethod
    def _weight_class_key(raw) -> str:
        """Return the ``weight_classes`` lookup key for a raw style/weight_type."""
        # str() of a (str, Enum) member yields "Cls.member" rather than its
        # value, which would never match a registry key; unwrap enums first.
        if isinstance(raw, Enum):
            raw = raw.value
        return str(raw)

    @classmethod
    def _build_weight_kernel(cls, item: dict):
        """Build one kernel from a dict, or log a warning and return None."""
        # Use the 'style' field to determine which kernel class to use.
        style = cls._weight_class_key(item.get("style", ""))
        if style in weight_classes:
            try:
                return weight_classes[style](**item)
            except Exception as e:
                # Best-effort: a malformed kernel is skipped, not fatal.
                logger.warning(
                    f"Failed to create weight kernel with style '{style}': {e}"
                )
                return None

        # Fallback to weight_type for backward compatibility.
        weight_type = cls._weight_class_key(item.get("weight_type", ""))
        if weight_type in weight_classes:
            try:
                return weight_classes[weight_type](**item)
            except Exception as e:
                logger.warning(
                    f"Failed to create weight kernel with weight_type '{weight_type}': {e}"
                )
                return None

        logger.warning(
            f"Neither style '{style}' nor weight_type '{weight_type}' recognized -- skipping"
        )
        return None

    def evaluate(self, feature_values):
        """
        Evaluate this feature's weighting based on the list of kernels.

        Parameters
        ----------
        feature_values : np.ndarray or float
            The computed values for this feature.

        Returns
        -------
        combined_weight : np.ndarray or float
            The element-wise product (along axis 0 of the stacked kernel
            outputs) of every kernel's ``evaluate`` result, or ``1.0`` when
            no kernels are configured.
        """
        weights = [kernel.evaluate(feature_values) for kernel in self.weight_kernels]
        return np.prod(weights, axis=0) if weights else 1.0