Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mt_metadata \ mt_metadata \ features \ weights \ channel_weight_spec.py: 97%

79 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1""" 

2Container for weighting strategy to apply to a single tf estimation 

3having a single output channel (usually one of "ex", "ey", "hz"). 

4 

5candidate data structure is stored in test_helpers/channel_weight_specs_example.json 

6 

7Candidate names: processing_weights, feature_weights, channel_weights_spec, channel_weighting 

8 

9Notes, and doc for weights PR. 

10 

11channel_weight_specs is a candidate name for the json block like the following: 

12>>> diff processing_configuration_template.json test_processing_config_with_weights_block.json 

13(Another candidate name could be `processing_weights`, or `weights`, but the final nomenclature 

14can be sorted out after there is a functional prototype with the appropriate structure.) 

15 

16 

17This block is basically a dict that maps an output channel name to a ChannelWeightSpec (CWS) object. 

18 

19There are at least three places we would like to be able to plug in such a dict to the processing flow. 

201. At the frequency_band level, so that each band can be associated with a specialty CWS 

212. At the decimation_level level, so that all bands in a GIB have a common default.

223. At a high level, so that all processing uses them. 

23TAI: In future, hopefully we could insert a custom CWS for a specific band, but leave 

24all other bands to use the DecimationLevel default CWS, for example. i.e. the CWS can 

25be defined for different scopes. 

26 

27TODO FIXME: In mt_metadata/transfer_functions/processing/aurora/processing.py

28when you output a json, it looks like the `decimations` level should be named: 

29`decimation_levels` instead. 

30 

31The general model I'll try to follow will be to open an iterable of objects

32with a plural of the object name. For example, the processing block called "bands" 

33follows with an iterable of:

34{ 

35 "band": { 

36 "center_averaging_type": "geometric", 

37 ... 

38 "index_min": 25 

39 } 

40} 

41... 

42{ 

43 "band": { 

44 "center_averaging_type": "geometric", 

45 ... 

46 "index_min": 25 

47 } 

48} 

49 

50Will start by plugging this into the DecimationLevel. 

51 

52TODO: Determine if this class, which represents a single element of a list 

53of channel weight specs, which will be in the json, should have a wrapper or not. 

54 

55In the same way that a DecimationLevel has Bands, 

56it will also have ChannelWeightSpecs. 

57""" 

58 

59# ===================================================== 

60# Imports 

61# ===================================================== 

62from typing import Annotated 

63 

64import numpy as np 

65import xarray as xr 

66from pydantic import Field, field_validator, ValidationInfo 

67 

68from mt_metadata.base import MetadataBase 

69from mt_metadata.common.band import Band 

70from mt_metadata.common.enumerations import StrEnumerationBase 

71from mt_metadata.features.weights.feature_weight_spec import FeatureWeightSpec 

72 

73 

74# ===================================================== 

class CombinationStyleEnum(StrEnumerationBase):
    """
    Names of the rules available for merging several feature weights.

    Each member's value is the same string as its name; these are the
    strings that ``ChannelWeightSpec.evaluate`` compares against when it
    picks a reduction.
    """

    # Element-wise product of the individual feature weights.
    multiplication = "multiplication"
    # Element-wise smallest feature weight.
    minimum = "minimum"
    # Element-wise largest feature weight.
    maximum = "maximum"
    # Element-wise arithmetic mean of the feature weights.
    mean = "mean"

80 

81 

class ChannelWeightSpec(MetadataBase):
    """
    Weighting scheme applied to one or more transfer-function output channels.

    A spec lists the output channels it applies to, the feature weight specs
    whose individual weights are evaluated, and the rule
    (``combination_style``) used to merge those weights into one channel
    weight.  Computed weights may be stored on ``weights`` and then queried
    per band with :meth:`get_weights_for_band`.
    """

    combination_style: Annotated[
        CombinationStyleEnum,
        Field(
            default="multiplication",
            description="How to combine multiple feature weights.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["multiplication"],
            },
        ),
    ]

    output_channels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of tf ouput channels for which this weighting scheme will be applied",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[ ex ey hz ]"],
            },
        ),
    ]

    feature_weight_specs: Annotated[
        list[FeatureWeightSpec],
        Field(
            default_factory=list,
            description="List of feature weighting schemes to use for TF processing.",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[]"],
            },
        ),
    ]

    weights: Annotated[
        xr.DataArray | xr.Dataset | np.ndarray | None,
        Field(
            default=None,
            description="Weights computed for this channel weight spec. Should be set after evaluation.",
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["null"],
            },
        ),
    ]

    @field_validator("feature_weight_specs", mode="before")
    @classmethod
    def check_feature_weight_specs(cls, value, info: ValidationInfo):
        """
        Normalize the input into a list of ``FeatureWeightSpec`` objects.

        A value that is not a list is wrapped in a one-element list.
        ``FeatureWeightSpec`` items are kept as they are and ``dict`` items
        are expanded into ``FeatureWeightSpec(**item)``.

        Raises
        ------
        TypeError
            If an item is neither a ``FeatureWeightSpec`` nor a ``dict``.
        """
        if not isinstance(value, list):
            value = [value]

        result = []
        for item in value:
            if isinstance(item, FeatureWeightSpec):
                result.append(item)
            elif isinstance(item, dict):
                # Construct directly to ensure validators run
                result.append(FeatureWeightSpec(**item))
            else:
                raise TypeError(f"Expected FeatureWeightSpec or dict, got {type(item)}")

        return result

    @field_validator("weights", mode="before")
    @classmethod
    def check_weights(cls, value):
        """
        Accept only ``None``, a numpy array, or an xarray DataArray/Dataset.

        Raises
        ------
        TypeError
            If ``value`` is any other type.
        """
        if not isinstance(value, (xr.DataArray, xr.Dataset, np.ndarray, type(None))):
            raise TypeError("Data must be a numpy array or xarray.")

        # QUESTION: do we want to cast the value to a specific type here?
        return value

    def evaluate(
        self, feature_values_dict: dict[str, np.ndarray | float]
    ) -> float | np.ndarray:
        """
        Evaluate the channel weight by combining weights from all features.

        Parameters
        ----------
        feature_values_dict : dict[str, np.ndarray | float]
            Dictionary mapping feature names to their computed values.
            e.g., {"coherence": ndarray, "multiple_coherence": ndarray}

        Returns
        -------
        channel_weight : float or np.ndarray
            ``1.0`` when there are no feature weight specs; otherwise the
            feature weights reduced according to ``combination_style``.

        Raises
        ------
        KeyError
            If a feature named by one of the specs is missing from
            ``feature_values_dict``.
        ValueError
            If ``combination_style`` is not one of the known styles, or the
            individual weights cannot be broadcast to a common shape.
        """
        weights = []
        for feature_weight_spec in self.feature_weight_specs:
            fname = feature_weight_spec.feature.name
            if fname not in feature_values_dict:
                raise KeyError(f"Feature values missing for '{fname}'")

            weights.append(feature_weight_spec.evaluate(feature_values_dict[fname]))

        if not weights:
            return 1.0

        combo = self.combination_style
        if combo == "multiplication":
            reducer = np.prod
        elif combo == "mean":
            reducer = np.mean
        elif combo == "minimum":
            reducer = np.min
        elif combo == "maximum":
            reducer = np.max
        else:
            raise ValueError(f"Unknown combination style: {combo}")

        # Bring every weight to a common shape first.  Reducing a plain list
        # that mixes scalars with arrays (or arrays of different but
        # compatible shapes) would otherwise make numpy build a ragged array
        # and fail.
        return reducer(np.broadcast_arrays(*weights), axis=0)

    def get_weights_for_band(
        self, band: Band
    ) -> np.ndarray | xr.DataArray | xr.Dataset:
        """
        Extract weights for the frequency bin closest to the band's center frequency.

        TODO: Add tests.

        Parameters
        ----------
        band : Band
            Should have a .center_frequency attribute (float, Hz).

        Returns
        -------
        weights : np.ndarray, xarray.DataArray or xarray.Dataset
            Weights for the closest frequency bin, selected along the
            frequency dimension (xarray) or along the first axis (numpy).

        Raises
        ------
        ValueError
            If no weights have been set, or if xarray weights have no
            dimension whose name contains "freq".
        TypeError
            If the stored weights are neither xarray objects nor a numpy
            array.

        Notes
        -----
        For a plain numpy array no frequency values are available, so the
        integer positions along the first axis are compared against
        ``band.center_frequency``.
        """
        weights = self.weights
        if weights is None:
            raise ValueError("No weights have been set.")

        if hasattr(weights, "dims"):
            # xarray: the first dimension whose name contains "freq" is taken
            # as the frequency axis.  Dimension names need not be strings, so
            # skip the ones that cannot be substring-tested.
            freq_axis = next(
                (
                    dim
                    for dim in weights.dims
                    if isinstance(dim, str) and "freq" in dim
                ),
                None,
            )
            if freq_axis is None:
                raise ValueError("Could not find frequency dimension in weights.")

            freqs = weights[freq_axis].values
        elif isinstance(weights, np.ndarray):
            # If it's a plain ndarray, assume first axis is frequency
            freqs = np.arange(weights.shape[0])
            freq_axis = 0
        else:
            raise TypeError(
                "Weights must be an xarray.DataArray, Dataset, or numpy array."
            )

        # Find index of closest frequency
        idx = np.argmin(np.abs(freqs - band.center_frequency))

        # Extract weights for that frequency
        if hasattr(weights, "isel"):
            # xarray: use isel
            return weights.isel({freq_axis: idx})

        # numpy: index along first axis
        return weights[idx]