Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mt_metadata \ mt_metadata \ processing \ aurora \ processing.py: 43%

142 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4from typing import Annotated 

5 

6from loguru import logger 

7from pydantic import computed_field, Field, field_validator 

8 

9from mt_metadata.base import MetadataBase 

10from mt_metadata.common.band import Band 

11from mt_metadata.common.enumerations import StrEnumerationBase 

12from mt_metadata.processing.aurora.channel_nomenclature import ChannelNomenclature 

13from mt_metadata.processing.aurora.decimation_level import DecimationLevel 

14from mt_metadata.processing.aurora.stations import Stations 

15 

16 

17# ===================================================== 

18class BandSpecificationStyleEnum(StrEnumerationBase): 

19 EMTF = "EMTF" 

20 band_edges = "band_edges" 

21 

22 

23class Processing(MetadataBase): 

24 decimations: Annotated[ 

25 list[DecimationLevel], 

26 Field( 

27 default_factory=list, 

28 description="decimation levels", 

29 alias=None, 

30 json_schema_extra={ 

31 "units": None, 

32 "required": True, 

33 "examples": ["0"], 

34 }, 

35 ), 

36 ] 

37 

38 band_specification_style: Annotated[ 

39 BandSpecificationStyleEnum | None, 

40 Field( 

41 default=None, 

42 description="describes how bands were sourced", 

43 alias=None, 

44 json_schema_extra={ 

45 "units": None, 

46 "required": False, 

47 "examples": ["EMTF"], 

48 }, 

49 ), 

50 ] 

51 

52 band_setup_file: Annotated[ 

53 str | None, 

54 Field( 

55 default=None, 

56 description="the band setup file used to define bands", 

57 alias=None, 

58 json_schema_extra={ 

59 "units": None, 

60 "required": False, 

61 "examples": ["/home/user/bs_test.cfg"], 

62 }, 

63 ), 

64 ] 

65 

66 id: Annotated[ 

67 str, 

68 Field( 

69 default="", 

70 description="Configuration ID", 

71 alias=None, 

72 json_schema_extra={ 

73 "units": None, 

74 "required": True, 

75 "examples": ["0"], 

76 }, 

77 ), 

78 ] 

79 

80 channel_nomenclature: Annotated[ 

81 ChannelNomenclature, 

82 Field( 

83 default_factory=ChannelNomenclature, # type: ignore 

84 description="Channel nomenclature", 

85 alias=None, 

86 json_schema_extra={ 

87 "units": None, 

88 "required": True, 

89 "examples": ["EMTF"], 

90 }, 

91 ), 

92 ] 

93 

94 stations: Annotated[ 

95 Stations, 

96 Field( 

97 default_factory=Stations, # type: ignore 

98 description="Station information", 

99 alias=None, 

100 json_schema_extra={ 

101 "units": None, 

102 "required": True, 

103 "examples": ["Station1", "Station2"], 

104 }, 

105 ), 

106 ] 

107 

108 @field_validator("decimations", mode="before") 

109 @classmethod 

110 def validate_decimations(cls, value, info) -> list[DecimationLevel]: 

111 decimation_levels = [] 

112 if isinstance(value, DecimationLevel): 

113 decimation_levels.append(value) 

114 

115 elif isinstance(value, dict): 

116 for key, obj in value.items(): 

117 if not isinstance(obj, DecimationLevel): 

118 raise TypeError( 

119 f"List entry must be a DecimationLevel object not {type(obj)}" 

120 ) 

121 else: 

122 decimation_levels.append(obj) 

123 

124 elif isinstance(value, list): 

125 for obj in value: 

126 if isinstance(obj, DecimationLevel): 

127 decimation_levels.append(obj) 

128 for obj in value: 

129 if isinstance(obj, DecimationLevel): 

130 decimation_levels.append(obj) 

131 elif isinstance(obj, dict): 

132 level = DecimationLevel() # type: ignore 

133 level.from_dict(obj) 

134 decimation_levels.append(level) 

135 else: 

136 raise TypeError( 

137 f"List entry must be a DecimationLevel or dict object not {type(obj)}" 

138 ) 

139 # TODO: Add some doc describing the role of this weird check for a long string 

140 elif isinstance(value, str): 

141 if len(value) > 4: 

142 raise TypeError(f"Not sure what to do with {type(value)}") 

143 else: 

144 decimation_levels = [] 

145 

146 else: 

147 raise TypeError(f"Not sure what to do with {type(value)}") 

148 

149 return decimation_levels 

150 

151 @computed_field 

152 @property 

153 def decimations_dict(self) -> dict[int, DecimationLevel]: 

154 """ 

155 need to have a dictionary, but it can't be an attribute cause that 

156 gets confusing when reading in a json file 

157 

158 Returns 

159 ------- 

160 dict[int, DecimationLevel] 

161 A dictionary mapping decimation levels to their corresponding DecimationLevel objects. 

162 

163 """ 

164 return dict([(d.decimation.level, d) for d in self.decimations]) 

165 

166 def get_decimation_level(self, level: int) -> DecimationLevel: 

167 """ 

168 Get a decimation level for easy access 

169 

170 Parameters 

171 ---------- 

172 level: int 

173 The decimation level to retrieve. 

174 

175 Returns 

176 ------- 

177 DecimationLevel 

178 The DecimationLevel object corresponding to the specified level. 

179 

180 """ 

181 

182 try: 

183 decimation = self.decimations_dict[level] 

184 

185 except KeyError: 

186 raise KeyError(f"Could not find {level} in decimations.") 

187 

188 if isinstance(decimation, dict): 

189 decimation_level = DecimationLevel() # type: ignore 

190 decimation_level.from_dict(decimation) 

191 return decimation_level 

192 

193 return decimation 

194 

195 def add_decimation_level(self, decimation_level: DecimationLevel | dict): 

196 """ 

197 add a decimation level 

198 

199 Parameters 

200 ---------- 

201 decimation_level: DecimationLevel | dict 

202 The decimation level to add, either as a DecimationLevel object or a dictionary. 

203 Returns 

204 ------- 

205 None 

206 """ 

207 

208 if not isinstance(decimation_level, (DecimationLevel, dict)): 

209 raise TypeError( 

210 f"List entry must be a DecimationLevel object not {type(decimation_level)}" 

211 ) 

212 if isinstance(decimation_level, dict): 

213 obj = DecimationLevel() # type: ignore 

214 obj.from_dict(decimation_level) 

215 

216 else: 

217 obj = decimation_level 

218 

219 self.decimations.append(obj) 

220 

221 @computed_field 

222 @property 

223 def band_edges_dict(self) -> dict[int, list[tuple[float, float]]]: 

224 band_edges_dict = {} 

225 for i_dec, decimation in enumerate(self.decimations): 

226 band_edges_dict[i_dec] = decimation.band_edges 

227 return band_edges_dict 

228 

229 def assign_decimation_level_data_emtf(self, sample_rate: float): 

230 """ 

231 

232 Warning: This does not actually tell us how many samples we are decimating down 

233 at each level. That is assumed to be 4 but we need a way to bookkeep this in general 

234 

235 Parameters 

236 ---------- 

237 sample_rate: float 

238 The initial sampling rate of the data before any decimation 

239 

240 """ 

241 for key in sorted(self.decimations_dict.keys()): 

242 if key in [0, "0"]: 

243 d = 1 

244 sr = sample_rate 

245 else: 

246 # careful with this hardcoded assumption of decimation by 4 

247 d = 4 

248 sr = sample_rate / (d ** int(key)) 

249 decimation_obj = self.decimations_dict[key] 

250 decimation_obj.decimation.factor = d 

251 decimation_obj.decimation.sample_rate = sr 

252 

253 def assign_bands( 

254 self, 

255 band_edges_dict: dict[int, list[tuple[float, float]]], 

256 sample_rate: float, 

257 decimation_factors: dict[int, int], 

258 num_samples_window: dict[int, int] | int = 256, 

259 ) -> None: 

260 """ 

261 

262 Warning: This does not actually tell us how many samples we are decimating down 

263 at each level. That is assumed to be 4 but we need a way to bookkeep this in general 

264 

265 Parameters 

266 ---------- 

267 band_edges: dict[int, list[tuple[float, float]]] 

268 A dictionary mapping decimation levels to lists of frequency band edges. 

269 keys are integers, starting with 0, values are arrays of edges 

270 

271 sample_rate: float 

272 The initial sampling rate of the data before any decimation. 

273 

274 decimation_factors: dict[int, int] 

275 A dictionary mapping decimation levels to their corresponding decimation factors. 

276 

277 num_samples_window: dict[int, int] | int, optional (default=256) 

278 The number of samples in the STFT window for each decimation level. If an integer is provided, 

279 it will be applied to all levels. If a dictionary is provided, it should map decimation levels to 

280 their corresponding number of samples. 

281 

282 Returns 

283 ------- 

284 None 

285 """ 

286 num_decimation_levels = len(band_edges_dict.keys()) 

287 if isinstance(num_samples_window, int): 

288 num_samples_window = num_decimation_levels * [num_samples_window] 

289 

290 for i_level in sorted(band_edges_dict.keys()): 

291 band_edges = band_edges_dict[i_level] 

292 if i_level in [0, "0"]: 

293 d = decimation_factors[i_level] # 1 

294 sr = sample_rate 

295 else: 

296 # careful with this hardcoded assumption of decimation by 4 

297 d = decimation_factors[i_level] # 4 

298 sr = 1.0 * sample_rate / (d ** int(i_level)) 

299 decimation_obj = DecimationLevel() 

300 decimation_obj.decimation.level = int(i_level) # self.decimations_dict[key] 

301 decimation_obj.decimation.factor = d 

302 decimation_obj.decimation.sample_rate = sr 

303 decimation_obj.stft.window.num_samples = num_samples_window[i_level] 

304 frequencies = decimation_obj.fft_frequencies 

305 

306 for low, high in band_edges: 

307 band = Band( # type: ignore 

308 decimation_level=i_level, 

309 frequency_min=low, 

310 frequency_max=high, 

311 ) 

312 band.set_indices_from_frequencies(frequencies) 

313 decimation_obj.add_band(band) 

314 self.add_decimation_level(decimation_obj) 

315 

316 def json_fn(self): 

317 json_fn = self.id + "_processing_config.json" 

318 return json_fn 

319 

320 @property 

321 def num_decimation_levels(self): 

322 return len(self.decimations) 

323 

324 def drop_reference_channels(self): 

325 for decimation in self.decimations: 

326 decimation.reference_channels = [] 

327 return 

328 

329 def set_input_channels(self, channels): 

330 for decimation in self.decimations: 

331 decimation.input_channels = channels 

332 

333 def set_output_channels(self, channels): 

334 for decimation in self.decimations: 

335 decimation.output_channels = channels 

336 

337 def set_reference_channels(self, channels): 

338 for decimation in self.decimations: 

339 decimation.reference_channels = channels 

340 

341 def set_default_input_output_channels(self): 

342 self.set_input_channels(self.channel_nomenclature.default_input_channels) 

343 self.set_output_channels(self.channel_nomenclature.default_output_channels) 

344 

345 def set_default_reference_channels(self): 

346 self.set_reference_channels( 

347 self.channel_nomenclature.default_reference_channels 

348 ) 

349 

350 def validate_processing(self, kernel_dataset): 

351 """ 

352 Placeholder. Some of the checks and methods here maybe better placed in 

353 TFKernel, which would validate the dataset against the processing config. 

354 

355 Things that are validated: 

356 1. The default estimation engine from the json file is "RME_RR", which is fine ( 

357 we expect to in general to do more RR processing than SS) but if there is only 

358 one station (no remote)then the RME_RR should be replaced by default with "RME". 

359 

360 2. make sure local station id is defined (correctly from kernel dataset) 

361 """ 

362 

363 # Make sure a RR method is not being called for a SS config 

364 if not self.stations.remote: 

365 for decimation in self.decimations: 

366 if decimation.estimator.engine == "RME_RR": 

367 logger.info("No RR station specified, switching RME_RR to RME") 

368 decimation.estimator.engine = "RME" 

369 

370 # Make sure that a local station is defined 

371 if not self.stations.local.id: 

372 logger.warning( 

373 "Local station not specified, should be set from Kernel Dataset" 

374 ) 

375 self.stations.from_dataset_dataframe(kernel_dataset.df)