Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\processing\aurora\decimation_level.py: 79%

224 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1""" 

2This module contains the DecimationLevel class. 

3TODO: Factor or rename. The decimation level class here has information about the entire processing. 

4""" 

5 

6# ===================================================== 

7# Imports 

8# ===================================================== 

9from typing import Annotated, get_args, List, Union 

10 

11import numpy as np 

12import pandas as pd 

13from loguru import logger 

14from pydantic import computed_field, Field, field_validator, ValidationInfo 

15 

16from mt_metadata.base import MetadataBase 

17from mt_metadata.common.band import Band 

18from mt_metadata.common.enumerations import StrEnumerationBase 

19from mt_metadata.helper_functions import cast_to_class_if_dict, validate_setter_input 

20from mt_metadata.processing import ShortTimeFourierTransform as STFT 

21from mt_metadata.processing import TimeSeriesDecimation as Decimation 

22 

23 

24from mt_metadata.features.weights import ChannelWeightSpec 

25 

26from mt_metadata.processing.aurora.estimator import Estimator 

27from mt_metadata.processing.aurora.frequency_bands import FrequencyBands 

28from mt_metadata.processing.aurora.regression import Regression 

29 

30 

31from mt_metadata.processing.fourier_coefficients.decimation import ( 

32 Decimation as FCDecimation, 

33) 

34 

35 

36# ===================================================== 

# Closed set of storage formats accepted by ``DecimationLevel.save_fcs_type``.
# Each member's value is the same string as its name.
class SaveFcsTypeEnum(StrEnumerationBase):
    h5 = "h5"  # store Fourier coefficients in an h5 container
    csv = "csv"  # store Fourier coefficients as csv

40 

41 

# Processing configuration for one decimation level: the frequency bands, the
# roles of the channels (input / output / reference), whether and how Fourier
# coefficients are saved, and the decimation, estimator, regression and STFT
# settings.  Also provides helpers to compare against, or generate, the
# FCDecimation metadata used for archived spectrograms.
class DecimationLevel(MetadataBase):
    # Frequency bands to process at this decimation level.
    bands: Annotated[
        list[Band],
        Field(
            default_factory=list,
            description="List of bands",
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[]"],
            },
        ),
    ]

    channel_weight_specs: Annotated[
        List[ChannelWeightSpec],
        Field(
            default_factory=list,
            description="List of weighting schemes to use for TF processing for each output channel",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["[]"],
            },
        ),
    ]

    input_channels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of input channels (sources)",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["hx, hy"],
            },
        ),
    ]

    output_channels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of output channels (responses)",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["ex, ey, hz"],
            },
        ),
    ]

    reference_channels: Annotated[
        list[str],
        Field(
            default_factory=list,
            description="list of reference channels (remote sources)",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": ["hx, hy"],
            },
        ),
    ]

    save_fcs: Annotated[
        bool,
        Field(
            default=False,
            description="Whether the Fourier coefficients are saved [True] or not [False].",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": True,
                "examples": [True],
            },
        ),
    ]

    save_fcs_type: Annotated[
        SaveFcsTypeEnum | None,
        Field(
            default=None,
            description="Format to use for fc storage",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["h5"],
            },
        ),
    ]

    decimation: Annotated[
        Decimation,
        Field(
            default_factory=Decimation,  # type: ignore
            description="Decimation settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["Decimation()"],
            },
        ),
    ]

    estimator: Annotated[
        Estimator,
        Field(
            default_factory=Estimator,  # type: ignore
            description="Estimator settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["Estimator()"],
            },
        ),
    ]

    regression: Annotated[
        Regression,
        Field(
            default_factory=Regression,  # type: ignore
            description="Regression settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["Regression()"],
            },
        ),
    ]

    stft: Annotated[
        STFT,
        Field(
            default_factory=STFT,  # type: ignore
            description="Short-time Fourier transform settings",
            alias=None,
            json_schema_extra={
                "units": None,
                "required": False,
                "examples": ["STFT()"],
            },
        ),
    ]

    @field_validator("channel_weight_specs", mode="before")
    @classmethod
    def validate_channel_weight_specs(cls, value, info: ValidationInfo):
        """
        Normalize ``channel_weight_specs`` input to a list of ChannelWeightSpec.

        Parameters
        ----------
        value: ChannelWeightSpec, dict, or list of those
            A single spec (object or dict) is wrapped in a list.
        info: ValidationInfo
            Supplied by pydantic; not used here.

        Returns
        -------
        validated_specs: list
            ChannelWeightSpec objects; dict entries are expanded as keyword
            arguments to ChannelWeightSpec.

        Raises
        ------
        TypeError
            If ``value`` is not a list after singleton handling, or a list
            entry is neither a dict nor a ChannelWeightSpec.
        """

        # Handle singleton cases
        if isinstance(value, (ChannelWeightSpec, dict)):
            value = [value]

        if not isinstance(value, list):
            raise TypeError(f"Not sure what to do with {type(value)}")

        # Convert dicts to ChannelWeightSpecs objects
        validated_specs = []
        for item in value:
            if isinstance(item, dict):
                validated_specs.append(ChannelWeightSpec(**item))
            elif isinstance(item, ChannelWeightSpec):
                validated_specs.append(item)
            else:
                raise TypeError(
                    f"List entry must be a ChannelWeightSpec object or dict, not {type(item)}"
                )

        return validated_specs

    @field_validator("bands", mode="before")
    @classmethod
    def validate_bands(cls, value, info: ValidationInfo):
        """
        Normalize ``bands`` input to a list of the field's element class.

        The element class is read from the field annotation rather than
        hard-coded, then the input is passed through ``validate_setter_input``
        and each entry through ``cast_to_class_if_dict``.

        Raises
        ------
        ValueError
            If pydantic does not supply a field name.
        """
        # Get the field type dynamically from the model
        field_name = info.field_name
        if field_name is None:
            raise ValueError("Field name is required for validation")

        field_info = cls.model_fields[field_name]

        # Extract the target class from List[TargetClass] annotation
        target_class = get_args(field_info.annotation)[0]

        values = validate_setter_input(value, target_class)
        return [cast_to_class_if_dict(obj, target_class) for obj in values]

    def add_band(self, band: Union[Band, dict]) -> None:
        """
        Append a band to ``self.bands``.

        Parameters
        ----------
        band: Band or dict
            A dict is loaded into a new Band via ``Band.from_dict``.

        Raises
        ------
        TypeError
            If ``band`` is neither a Band nor a dict.
        """

        if not isinstance(band, (Band, dict)):
            raise TypeError(f"List entry must be a Band object not {type(band)}")
        if isinstance(band, dict):
            obj = Band()
            obj.from_dict(band)
        else:
            obj = band

        self.bands.append(obj)

    @computed_field
    @property
    def lower_bounds(self) -> np.ndarray:
        """
        get lower bounds index values into an array.
        """

        # Sorted ascending, so order is independent of the order of self.bands.
        return np.array(sorted([band.index_min for band in self.bands]))

    @computed_field
    @property
    def upper_bounds(self) -> np.ndarray:
        """
        get upper bounds index values into an array.
        """

        # Sorted ascending, so order is independent of the order of self.bands.
        return np.array(sorted([band.index_max for band in self.bands]))

    @computed_field
    @property
    def bands_dataframe(self) -> pd.DataFrame:
        """
        Utility function that transforms a list of bands into a dataframe

        See notes in `_df_from_bands`.

        Returns
        -------
        bands_df: pd.Dataframe
            Same format as that generated by EMTFBandSetupFile.get_decimation_level()
        """
        bands_df = _df_from_bands(self.bands)
        return bands_df

    @computed_field
    @property
    def frequency_sample_interval(self) -> float:
        """
        Returns the delta_f in frequency domain df = 1 / (N * dt)
        Here dt is the sample interval after decimation

        Returns
        -------
        frequency_sample_interval: float
            The frequency sample interval after decimation.
        """
        # 1 / (N * dt) == sample_rate / N
        return self.decimation.sample_rate / self.stft.window.num_samples

    @computed_field
    @property
    def band_edges(self) -> np.ndarray:
        """
        Returns the band edges as a numpy array

        Returns
        -------
        band_edges: 2D numpy array, one row per frequency band and two columns
        """
        bands_df = self.bands_dataframe
        # Stack (min, max) as two rows, then transpose to one row per band.
        band_edges = np.vstack(
            (bands_df.frequency_min.values, bands_df.frequency_max.values)
        ).T
        return band_edges

    def frequency_bands_obj(self) -> FrequencyBands:
        """
        Gets a FrequencyBands object that is used as input to processing.

        Used by Aurora.

        TODO: consider adding .to_frequency_bands() method directly to self.bands

        Returns
        -------
        frequency_bands: FrequencyBands
            A FrequencyBands object that can be used as an iterator for processing.

        """
        frequency_bands = FrequencyBands(band_edges=self.band_edges)
        return frequency_bands

    @property
    def fft_frequencies(self) -> np.ndarray:
        """
        Gets the harmonics of the STFT.

        Returns
        -------
        freqs: np.ndarray
            The frequencies at which the stft will be available.
        """
        freqs = self.stft.window.fft_harmonics(self.decimation.sample_rate)
        return freqs

    @property
    def harmonic_indices(self) -> List[int]:
        """
        Loops over all bands and returns a list of the harmonic indices.
        TODO: Distinguish the bands which are a processing construction vs harmonic indices which are FFT info.

        Returns
        -------
        return_list: list of integers
            The indices of the harmonics that are needed for processing,
            sorted ascending.  Indices shared by several bands are not
            de-duplicated.
        """
        return_list = []
        for band in self.bands:
            return_list += band.harmonic_indices.tolist()
        return_list.sort()
        return return_list

    @property
    def local_channels(self) -> list[str]:
        """
        Input channels followed by output channels, as a new list.
        """
        return self.input_channels + self.output_channels

    def is_consistent_with_archived_fc_parameters(
        self, fc_decimation: FCDecimation, remote: bool
    ) -> bool:
        """
        Usage: For an already existing spectrogram stored in an MTH5 archive, this compares the metadata
        within the archive (fc_decimation) with an aurora decimation level (self), and tells whether the
        parameters are in agreement.  If True, this allows aurora to skip the calculation of FCs and instead
        read them from the archive.

        The comparisons are explicit ``if`` tests rather than ``assert``
        statements so that they are still performed when Python runs with
        assertions disabled (``-O``).

        TODO: Merge all checks of TimeSeriesDecimation parameters into a single check.
        - e.g. Compress all decimation checks to: assert fc_decimation.decimation == self.decimation

        Parameters
        ----------
        fc_decimation: FCDecimation
            metadata describing the parameters used to compute an archived spectrogram
        remote: bool
            If True, we are looking for reference channels, not local channels in the FCGroup.

        Checks, in order:
            "channels_estimated": to ensure all expected channels are in the group
            "time_series_decimation.anti_alias_filter": check that the expected AAF was applied
            "time_series_decimation.sample_rate",
            "short_time_fourier_transform.method",
            "stft.prewhitening_type",
            "stft.recoloring",
            "stft.pre_fft_detrend_type",
            "stft.min_num_stft_windows",
            "stft.window",
            "stft.harmonic_indices",

        Returns
        -------
        bool
            True if no check failed; False (with an info-level log message)
            at the first failed check.

        Raises
        ------
        NotImplementedError
            If the anti-alias filters differ, other than the accepted case of
            a processing config of "default" against an archived value of None.
        """
        # channels_estimated: Checks that the archived spectrogram has the required channels
        if remote:
            required_channels = self.reference_channels
        else:
            required_channels = self.local_channels
        if not set(required_channels).issubset(fc_decimation.channels_estimated):
            msg = (
                f"required_channels for processing {required_channels} not available"
                f"-- fc channels estimated are {fc_decimation.channels_estimated}"
            )
            logger.info(msg)
            return False

        # anti_alias_filter: Check that the data were filtered the same way
        archived_aaf = fc_decimation.time_series_decimation.anti_alias_filter
        requested_aaf = self.decimation.anti_alias_filter
        if not (archived_aaf == requested_aaf):
            # The only tolerated mismatch: config says "default", archive says None.
            tolerated = (requested_aaf == "default") and (archived_aaf is None)
            if not tolerated:
                msg = (
                    "Antialias Filters Not Compatible -- need to add handling for "
                    f"FCdec {archived_aaf} and "
                    f"processing config:{requested_aaf}"
                )
                raise NotImplementedError(msg)

        # sample_rate
        archived_sample_rate = fc_decimation.time_series_decimation.sample_rate
        if not (archived_sample_rate == self.decimation.sample_rate):
            msg = (
                f"Sample rates do not agree: fc {archived_sample_rate} differs from "
                f"processing config {self.decimation.sample_rate}"
            )
            logger.info(msg)
            return False

        # transform method (fft, wavelet, etc.)
        archived_method = fc_decimation.short_time_fourier_transform.method
        if not (archived_method == self.stft.method):
            msg = (
                "Transform methods do not agree: "
                f"fc {archived_method} != processing config {self.stft.method}"
            )
            logger.info(msg)
            return False

        # Simple STFT attributes that must match exactly.  Each entry is
        # (attribute name, message prefix); evaluated in the listed order.
        stft_equality_checks = (
            ("prewhitening_type", "prewhitening_type does not agree "),
            ("recoloring", "recoloring does not agree "),
            ("pre_fft_detrend_type", "pre_fft_detrend_type does not agree "),
            ("min_num_stft_windows", "min_num_stft_windows do not agree "),
        )
        for attr_name, prefix in stft_equality_checks:
            archived_value = getattr(fc_decimation.stft, attr_name)
            requested_value = getattr(self.stft, attr_name)
            if not (archived_value == requested_value):
                msg = (
                    f"{prefix}"
                    f"fc {archived_value} != processing config {requested_value}"
                )
                logger.info(msg)
                return False

        # window
        if not (fc_decimation.stft.window == self.stft.window):
            msg = (
                f"window does not agree:  FC Group: {fc_decimation.stft.window} "
                f" Processing Config {self.stft.window}"
            )
            logger.info(msg)
            return False

        if fc_decimation.stft.harmonic_indices is None:
            # harmonic_indices not set, skip this check
            pass
        elif -1 in fc_decimation.stft.harmonic_indices:
            # if harmonic_indices is -1, it means the archive kept all so we can skip this check.
            pass
        else:
            msg = "WIP: harmonic indices in AuroraDecimationlevel are derived from processing bands -- Not robustly tested to compare with FCDecimation"
            logger.debug(msg)
            fcdec_group_set = set(fc_decimation.stft.harmonic_indices)
            processing_set = set(self.harmonic_indices)
            # Every harmonic needed for processing must be present in the archive.
            if not processing_set.issubset(fcdec_group_set):
                msg = (
                    f"Processing FC indices {processing_set} is not contained "
                    f"in FC indices {fcdec_group_set}"
                )
                logger.info(msg)
                return False

        # Getting here means no checks were failed.  The FCDecimation supports the processing config
        return True

    def to_fc_decimation(
        self,
        remote: bool = False,
        ignore_harmonic_indices: bool = True,
    ) -> FCDecimation:
        """
        Generates a FC Decimation() object for use with FC Layer in mth5.

        TODO: this is being tested only in aurora -- move a test to mt_metadata or move the method.
        Ignoring for now these properties
        "time_period.end": "1980-01-01T00:00:00+00:00",
        "time_period.start": "1980-01-01T00:00:00+00:00",

        TODO: FIXME: Assignment of TSDecimation can be done in one shot once #235 is addressed.

        Parameters
        ----------
        remote: bool
            If True, use reference channels, if False, use local_channels.  We may wish to not pass remote=True when
            _building_ FCs however, because then not all channels will get built.
        ignore_harmonic_indices: bool
            If True, leave harmonic indices at default [-1,], which means all indices.  If False, only the specific
            harmonic indices needed for processing will be stored.  Thus, when building FCs, it maybe best to leave
            this as True, that way all FCs will be stored, so if the band setup is changed, the FCs will still be there.

        Returns:
            fc_dec_obj:mt_metadata.transfer_functions.processing.fourier_coefficients.decimation.Decimation
                A decimation object configured for STFT processing

        """
        # NOTE(review): stft.min_num_stft_windows is not copied here, although
        # is_consistent_with_archived_fc_parameters compares it -- confirm
        # whether that is intentional.

        fc_dec_obj = FCDecimation()  # type: ignore
        fc_dec_obj.time_series_decimation.anti_alias_filter = (
            self.decimation.anti_alias_filter
        )
        if remote:
            fc_dec_obj.channels_estimated = self.reference_channels
        else:
            fc_dec_obj.channels_estimated = self.local_channels
        fc_dec_obj.time_series_decimation.factor = self.decimation.factor
        fc_dec_obj.time_series_decimation.level = self.decimation.level
        if not ignore_harmonic_indices:
            # Now that harmonic_indices is list[int], this should work
            fc_dec_obj.stft.harmonic_indices = self.harmonic_indices
        fc_dec_obj.id = f"{self.decimation.level}"
        fc_dec_obj.stft.method = self.stft.method
        fc_dec_obj.stft.pre_fft_detrend_type = self.stft.pre_fft_detrend_type
        fc_dec_obj.stft.prewhitening_type = self.stft.prewhitening_type
        fc_dec_obj.stft.recoloring = self.stft.recoloring
        fc_dec_obj.time_series_decimation.sample_rate = self.decimation.sample_rate
        fc_dec_obj.stft.window = self.stft.window

        return fc_dec_obj

607 

608 

def _df_from_bands(band_list: List[Union[Band, dict, None]]) -> pd.DataFrame:
    """
    Build a dataframe with one row per band, ordered by lower bound index.

    Note: The decimation_level column is each band's decimation_level +1, to
    agree with EMTF convention.  Not clear this is really necessary.
    TODO: Consider making this a method of FrequencyBands() class.
    TODO: Check typehint -- should None be allowed value in the band_list?
     (every entry is accessed by attribute below)
    TODO: Consider adding columns lower_closed, upper_closed to df

    Parameters
    ----------
    band_list: list
        obtained from mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel.bands

    Returns
    -------
    out_df: pd.Dataframe
        Columns: decimation_level, lower_bound_index, upper_bound_index,
        frequency_min, frequency_max.  Rows are sorted by lower_bound_index
        and the index is reset to 0..n-1.
        Same format as that generated by EMTFBandSetupFile.get_decimation_level()
    """
    # One list per column, each in band_list order.
    column_data = {
        "decimation_level": [band.decimation_level + 1 for band in band_list],
        "lower_bound_index": [band.index_min for band in band_list],
        "upper_bound_index": [band.index_max for band in band_list],
        "frequency_min": [band.frequency_min for band in band_list],
        "frequency_max": [band.frequency_max for band in band_list],
    }
    unsorted_df = pd.DataFrame(data=column_data)
    out_df = unsorted_df.sort_values(by="lower_bound_index").reset_index(drop=True)
    return out_df

650 

651 

def get_fft_harmonics(samples_per_window: int, sample_rate: float) -> np.ndarray:
    """
    Return the one-sided FFT frequencies for a window, starting at DC.

    Accepts odd and even window lengths; the first
    ``int(samples_per_window / 2)`` entries of ``np.fft.fftfreq`` are kept.

    Development notes:
     Could be modified with kwargs to support one_sided, two_sided, ignore_dc
     ignore_nyquist, and etc.  Consider taking FrequencyBands as an argument.

    Parameters
    ----------
    samples_per_window: integer
        Number of samples in a window that will be Fourier transformed.
    sample_rate: float
        Inverse of time step between samples,
        Samples per second

    Returns
    -------
    harmonic_frequencies: numpy array
        The frequencies that the fft will be computed.
        These are one-sided (positive frequencies only)
        Does not return Nyquist
        Does return DC component
    """
    sample_interval = 1.0 / sample_rate
    # Keep only the leading half of the bins: DC upward, no bin at Nyquist.
    n_kept = int(samples_per_window / 2)
    all_frequencies = np.fft.fftfreq(samples_per_window, d=sample_interval)
    return all_frequencies[:n_kept]

680 return harmonic_frequencies