Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\features\weights\channel_weight_spec.py: 97%
79 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-10 00:11 -0800
1"""
2Container for weighting strategy to apply to a single tf estimation
3having a single output channel (usually one of "ex", "ey", "hz").
A candidate data structure is stored in test_helpers/channel_weight_specs_example.json.
7Candidate names: processing_weights, feature_weights, channel_weights_spec, channel_weighting
Notes and documentation for the weights PR.
channel_weight_specs is a candidate name for a JSON block like the one shown by:
>>> diff processing_configuration_template.json test_processing_config_with_weights_block.json
(Other candidate names are `processing_weights` or `weights`, but the final nomenclature
can be settled once there is a functional prototype with the appropriate structure.)
17This block is basically a dict that maps an output channel name to a ChannelWeightSpec (CWS) object.
There are at least three places where we would like to be able to plug such a dict into the processing flow:
1. At the frequency_band level, so that each band can be associated with a specialty CWS.
2. At the decimation_level level, so that all bands in a GIB share a common default.
3. At a high level, so that all processing uses them.
TAI: In the future, we could hopefully insert a custom CWS for a specific band while leaving
all other bands to use the DecimationLevel default CWS; i.e., the CWS can
be defined at different scopes.
TODO FIXME: In mt_metadata/transfer_functions/processing/aurora/processing.py,
when you output a JSON, it looks like the `decimations` level should be named
`decimation_levels` instead.
The general model I'll try to follow is to open an iterable of objects
with the plural of the object name. For example, the processing block called "bands"
is followed by an iterable of:
34{
35 "band": {
36 "center_averaging_type": "geometric",
37 ...
38 "index_min": 25
39 }
40}
41...
42{
43 "band": {
44 "center_averaging_type": "geometric",
45 ...
46 "index_min": 25
47 }
48}
We will start by plugging this into the DecimationLevel.
TODO: Determine whether this class, which represents a single element of the list
of channel weight specs stored in the JSON, should have a wrapper.
55In the same way that a DecimationLevel has Bands,
56it will also have ChannelWeightSpecs.
57"""
59# =====================================================
60# Imports
61# =====================================================
62from typing import Annotated
64import numpy as np
65import xarray as xr
66from pydantic import Field, field_validator, ValidationInfo
68from mt_metadata.base import MetadataBase
69from mt_metadata.common.band import Band
70from mt_metadata.common.enumerations import StrEnumerationBase
71from mt_metadata.features.weights.feature_weight_spec import FeatureWeightSpec
74# =====================================================
class CombinationStyleEnum(StrEnumerationBase):
    """
    Allowed ways of combining several per-feature weights into one channel weight.

    ChannelWeightSpec.evaluate dispatches on these values; each style reduces
    the per-feature weights along their first (stacking) axis.
    """

    multiplication = "multiplication"  # product of the feature weights (np.prod)
    minimum = "minimum"  # smallest feature weight (np.min)
    maximum = "maximum"  # largest feature weight (np.max)
    mean = "mean"  # arithmetic mean of the feature weights (np.mean)
class ChannelWeightSpec(MetadataBase):
    """
    Weighting strategy for one or more transfer-function output channels.

    A ChannelWeightSpec holds a list of FeatureWeightSpec objects. Each one
    turns the values of a single feature into a weight, and the individual
    weights are combined according to ``combination_style`` (see
    :meth:`evaluate`). Computed weights may be stored on ``weights`` and then
    sliced for a given frequency band with :meth:`get_weights_for_band`.
    """

    # How the per-feature weights are combined in ``evaluate``.
    combination_style: Annotated[
        CombinationStyleEnum,
        Field(
            default="multiplication",
            description="How to combine multiple feature weights.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["multiplication"],
            },
        ),
    ]

    # Names of the TF output channels this weighting scheme applies to.
    output_channels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of tf output channels for which this weighting scheme will be applied",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[ ex ey hz ]"],
            },
        ),
    ]

    # Per-feature weighting schemes; coerced by check_feature_weight_specs.
    feature_weight_specs: Annotated[
        list[FeatureWeightSpec],
        Field(
            default_factory=list,
            description="List of feature weighting schemes to use for TF processing.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[]"],
            },
        ),
    ]

    # Computed weights (xarray or numpy); None until they have been set.
    weights: Annotated[
        xr.DataArray | xr.Dataset | np.ndarray | None,
        Field(
            default=None,
            description="Weights computed for this channel weight spec. Should be set after evaluation.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["null"],
            },
        ),
    ]

    @field_validator("feature_weight_specs", mode="before")
    @classmethod
    def check_feature_weight_specs(cls, value, info: ValidationInfo):
        """
        Coerce the input into a list of FeatureWeightSpec objects.

        ``None`` becomes an empty list, a tuple is converted to a list, and any
        other non-list value is treated as a single element. Each element must
        be a FeatureWeightSpec (kept as-is) or a dict (passed to the
        FeatureWeightSpec constructor so its own validators run).

        Raises
        ------
        TypeError
            If an element is neither a FeatureWeightSpec nor a dict.
        """
        if value is None:
            return []
        if isinstance(value, tuple):
            value = list(value)
        elif not isinstance(value, list):
            value = [value]

        result = []
        for item in value:
            if isinstance(item, FeatureWeightSpec):
                result.append(item)
            elif isinstance(item, dict):
                # Construct directly to ensure validators run
                result.append(FeatureWeightSpec(**item))
            else:
                raise TypeError(f"Expected FeatureWeightSpec or dict, got {type(item)}")

        return result

    @field_validator("weights", mode="before")
    @classmethod
    def check_weights(cls, value):
        """
        Accept only xarray objects, numpy arrays, or None for ``weights``.

        Raises
        ------
        TypeError
            If ``value`` is of any other type.
        """
        if not isinstance(value, (xr.DataArray, xr.Dataset, np.ndarray, type(None))):
            raise TypeError("Data must be a numpy array or xarray.")

        # QUESTION: do we want to cast the value to a specific type here?
        return value

    def evaluate(
        self, feature_values_dict: dict[str, np.ndarray | float]
    ) -> float | np.ndarray:
        """
        Evaluate the channel weight by combining weights from all features.

        Each FeatureWeightSpec is evaluated on the values stored under its
        feature name. The resulting weights are broadcast against each other,
        so e.g. a scalar weight can be combined with an array-valued one. They
        are then reduced according to ``combination_style``.

        Parameters
        ----------
        feature_values_dict : dict[str, np.ndarray | float]
            Dictionary mapping feature names to their computed values.
            e.g., {"coherence": ndarray, "multiple_coherence": ndarray}

        Returns
        -------
        channel_weight : float or np.ndarray
            1.0 when there are no feature weight specs; otherwise the
            combined weight.

        Raises
        ------
        KeyError
            If the values for a required feature are missing.
        ValueError
            If ``combination_style`` is not a supported style, or if the
            per-feature weights cannot be broadcast together.
        """
        weights = []
        for feature_weight_spec in self.feature_weight_specs:
            fname = feature_weight_spec.feature.name
            if fname not in feature_values_dict:
                raise KeyError(f"Feature values missing for '{fname}'")
            weights.append(feature_weight_spec.evaluate(feature_values_dict[fname]))

        if not weights:
            return 1.0

        # Reducing a ragged list (e.g. a scalar next to an array) with np.prod
        # and friends raises, so broadcast to a common shape and stack along a
        # new leading axis first. For same-shape inputs this is equivalent to
        # the plain list reduction.
        stacked = np.stack(np.broadcast_arrays(*weights))

        combo = self.combination_style
        if combo == "multiplication":
            return np.prod(stacked, axis=0)
        elif combo == "mean":
            return np.mean(stacked, axis=0)
        elif combo == "minimum":
            return np.min(stacked, axis=0)
        elif combo == "maximum":
            return np.max(stacked, axis=0)
        else:
            raise ValueError(f"Unknown combination style: {combo}")

    def get_weights_for_band(
        self, band: Band, frequencies: np.ndarray | None = None
    ) -> np.ndarray | xr.DataArray:
        """
        Extract weights for the frequency bin closest to the band's center frequency.

        TODO: Add tests.

        Parameters
        ----------
        band : Band
            Should have a .center_frequency attribute (float, Hz).
        frequencies : np.ndarray, optional
            Frequencies of the bins along the first axis of ``self.weights``.
            Only used when ``self.weights`` is a plain numpy array. If omitted,
            the bin *indices* (0, 1, 2, ...) are compared against
            ``band.center_frequency``. This is only meaningful if bin index and
            frequency coincide.

        Returns
        -------
        weights : np.ndarray or xarray.DataArray
            Weights for the closest frequency bin.

        Raises
        ------
        ValueError
            If no weights have been set, or no dimension whose name contains
            "freq" exists in xarray weights, or numpy weights are 0-d, or
            ``frequencies`` does not match the first axis of numpy weights.
        TypeError
            If weights are neither an xarray object nor a numpy array.
        """
        weights = self.weights
        if weights is None:
            raise ValueError("No weights have been set.")

        if hasattr(weights, "dims"):
            # xarray: use the first dimension whose (string) name contains "freq".
            freq_dim = next(
                (dim for dim in weights.dims if isinstance(dim, str) and "freq" in dim),
                None,
            )
            if freq_dim is None:
                raise ValueError("Could not find frequency dimension in weights.")
            freqs = np.asarray(weights[freq_dim].values)
            idx = int(np.argmin(np.abs(freqs - band.center_frequency)))
            return weights.isel({freq_dim: idx})

        if isinstance(weights, np.ndarray):
            if weights.ndim == 0:
                raise ValueError("Numpy weights must have at least one dimension.")
            n_bins = weights.shape[0]
            if frequencies is None:
                # Legacy behavior: assume the first axis is frequency and that
                # bin index stands in for frequency.
                freqs = np.arange(n_bins)
            else:
                freqs = np.asarray(frequencies, dtype=float)
                if freqs.shape != (n_bins,):
                    raise ValueError(
                        f"frequencies must have shape ({n_bins},), got {freqs.shape}"
                    )
            idx = int(np.argmin(np.abs(freqs - band.center_frequency)))
            return weights[idx]

        raise TypeError(
            "Weights must be an xarray.DataArray, Dataset, or numpy array."
        )