Coverage for C: \ Users \ peaco \ OneDrive \ Documents \ GitHub \ mt_metadata \ mt_metadata \ processing \ fourier_coefficients \ decimation.py: 83%

201 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4from collections import OrderedDict 

5from typing import Annotated, List, Optional 

6 

7import numpy as np 

8from loguru import logger 

9from pydantic import Field, field_validator, model_validator, ValidationInfo 

10 

11from mt_metadata.base import MetadataBase 

12from mt_metadata.common import ListDict, TimePeriod 

13from mt_metadata.processing import ShortTimeFourierTransform, TimeSeriesDecimation 

14from mt_metadata.processing.fourier_coefficients.fc_channel import FCChannel 

15 

16 

17# ===================================================== 

18class Decimation(MetadataBase): 

19 id: Annotated[ 

20 str, 

21 Field( 

22 default="", 

23 description="Decimation level ID", 

24 alias=None, 

25 json_schema_extra={ 

26 "units": None, 

27 "required": True, 

28 "examples": ["1"], 

29 }, 

30 ), 

31 ] 

32 

33 channels_estimated: Annotated[ 

34 list[str], 

35 Field( 

36 default_factory=list, 

37 description="list of channels", 

38 alias=None, 

39 json_schema_extra={ 

40 "units": None, 

41 "required": True, 

42 "examples": ["[ex, hy]"], 

43 }, 

44 ), 

45 ] 

46 

47 time_period: Annotated[ 

48 TimePeriod, 

49 Field( 

50 default_factory=TimePeriod, # type: ignore 

51 description="Time period over which these FCs were estimated", 

52 alias=None, 

53 json_schema_extra={ 

54 "units": None, 

55 "required": True, 

56 "examples": ["TimePeriod()"], 

57 }, 

58 ), 

59 ] 

60 

61 channels: Annotated[ 

62 ListDict, 

63 Field( 

64 default_factory=ListDict, 

65 description="List of channels", 

66 alias=None, 

67 json_schema_extra={ 

68 "units": None, 

69 "required": True, 

70 "examples": ["[ex, hy]"], 

71 }, 

72 ), 

73 ] 

74 

75 time_series_decimation: Annotated[ 

76 TimeSeriesDecimation, 

77 Field( 

78 default_factory=TimeSeriesDecimation, # type: ignore 

79 description="Time series decimation settings", 

80 alias=None, 

81 json_schema_extra={ 

82 "units": None, 

83 "required": True, 

84 "examples": ["TimeSeriesDecimation()"], 

85 }, 

86 ), 

87 ] 

88 

89 short_time_fourier_transform: Annotated[ 

90 ShortTimeFourierTransform, 

91 Field( 

92 default_factory=ShortTimeFourierTransform, # type: ignore 

93 description="Short time Fourier transform settings", 

94 alias=None, 

95 json_schema_extra={ 

96 "units": None, 

97 "required": True, 

98 "examples": ["ShortTimeFourierTransform()"], 

99 }, 

100 ), 

101 ] 

102 

103 @field_validator("short_time_fourier_transform", mode="before") 

104 @classmethod 

105 def validate_short_time_fourier_transform( 

106 cls, value: ShortTimeFourierTransform, info: ValidationInfo 

107 ) -> ShortTimeFourierTransform: 

108 if not isinstance(value, ShortTimeFourierTransform): 

109 msg = f"Input must be metadata ShortTimeFourierTransform not {type(value)}" 

110 raise TypeError(msg) 

111 if value.per_window_detrend_type: 

112 msg = f"per_window_detrend_type was set to {value.per_window_detrend_type}" 

113 msg += "however, this is not supported -- setting to empty string" 

114 logger.debug(msg) 

115 value.per_window_detrend_type = "" 

116 return value 

117 

118 @field_validator("channels_estimated", mode="before") 

119 @classmethod 

120 def validate_channels_estimated( 

121 cls, value: list[str], info: ValidationInfo 

122 ) -> list[str]: 

123 if not isinstance(value, list): 

124 msg = f"Input must be a list of strings not {type(value)}" 

125 raise TypeError(msg) 

126 for item in value: 

127 if not isinstance(item, str): 

128 msg = f"All items in the list must be strings not {type(item)}" 

129 raise TypeError(msg) 

130 return value 

131 

132 @field_validator("channels", mode="before") 

133 @classmethod 

134 def validate_channels(cls, value: ListDict, info: ValidationInfo) -> ListDict: 

135 # Handle None values first 

136 if value is None: 

137 return ListDict() 

138 

139 # Handle string representations that might come from HDF5 storage 

140 if isinstance(value, str): 

141 # If it's a string representation, try to parse it or return empty ListDict 

142 if value in ["", "none", "None", "ListDict()", "{}"]: 

143 return ListDict() 

144 # For other string values, try to maintain backward compatibility 

145 logger.warning(f"Converting string representation of channels: {value}") 

146 return ListDict() 

147 

148 if not isinstance(value, (list, tuple, dict, ListDict, OrderedDict)): 

149 msg = ( 

150 "input ch_list must be an iterable, should be a list or dict " 

151 f"not {type(value)}" 

152 ) 

153 logger.error(msg) 

154 raise TypeError(msg) 

155 

156 fails = [] 

157 channels = ListDict() 

158 if isinstance(value, (dict, ListDict, OrderedDict)): 

159 value_list = value.values() 

160 

161 elif isinstance(value, (list, tuple)): 

162 value_list = value 

163 

164 for ii, channel in enumerate(value_list): 

165 try: 

166 ch = FCChannel() 

167 if hasattr(channel, "to_dict"): 

168 channel = channel.to_dict() 

169 ch.from_dict(channel) 

170 channels.append(ch) 

171 except Exception as error: 

172 msg = "Could not create channel from dictionary: %s" 

173 fails.append(msg % error) 

174 logger.error(msg, error) 

175 

176 if len(fails) > 0: 

177 raise TypeError("\n".join(fails)) 

178 

179 return channels 

180 

181 @model_validator(mode="after") 

182 def validate_channels_consistency(self): 

183 """ 

184 Ensure that channels_estimated and channels are synchronized. 

185 

186 - If a channel name exists in channels_estimated but not in channels, 

187 create a new FCChannel with that component name 

188 - Ensure all channels in channels ListDict have their component names 

189 in channels_estimated 

190 """ 

191 channels_estimated = self.channels_estimated 

192 channels = self.channels 

193 

194 # Get existing channel component names from the channels ListDict 

195 existing_channel_names = set(channels.keys()) if channels.keys() else set() 

196 

197 # Get the set of estimated channel names 

198 estimated_channel_names = ( 

199 set(channels_estimated) if channels_estimated else set() 

200 ) 

201 

202 # Find channels that are estimated but don't exist in channels ListDict 

203 missing_channels = estimated_channel_names - existing_channel_names 

204 

205 # Create FCChannel objects for missing channels 

206 for channel_name in missing_channels: 

207 logger.info(f"Creating FCChannel for estimated channel: {channel_name}") 

208 new_channel = FCChannel(component=channel_name) 

209 channels.append(new_channel) 

210 

211 # Find channels in ListDict that aren't in channels_estimated and add them 

212 extra_channels = existing_channel_names - estimated_channel_names 

213 if extra_channels: 

214 logger.info(f"Adding channels to channels_estimated: {extra_channels}") 

215 # Add the extra channel names to channels_estimated 

216 self.channels_estimated.extend(list(extra_channels)) 

217 

218 return self 

219 

220 def add(self, other): 

221 """ 

222 

223 :param other: 

224 :return: 

225 """ 

226 if isinstance(other, Decimation): 

227 self.channels.extend(other.channels) 

228 

229 return self 

230 else: 

231 msg = f"Can only merge ch objects, not {type(other)}" 

232 logger.error(msg) 

233 raise TypeError(msg) 

234 

235 # ----- Begin (Possibly Temporary) methods for integrating TimeSeriesDecimation, STFT Classes -----# 

236 

237 @property 

238 def decimation(self) -> TimeSeriesDecimation: 

239 """ 

240 Passthrough method to access self.time_series_decimation 

241 """ 

242 return self.time_series_decimation 

243 

244 @property 

245 def stft(self): 

246 return self.short_time_fourier_transform 

247 

248 # ----- End (Possibly Temporary) methods for integrating TimeSeriesDecimation, STFT Classes -----# 

249 

250 def update(self, other, match=[]): 

251 """ 

252 Update attribute values from another like element, skipping None 

253 

254 :param other: DESCRIPTION 

255 :type other: TYPE 

256 :return: DESCRIPTION 

257 :rtype: TYPE 

258 

259 """ 

260 if not isinstance(other, type(self)): 

261 logger.warning("Cannot update %s with %s", type(self), type(other)) 

262 for k in match: 

263 if self.get_attr_from_name(k) != other.get_attr_from_name(k): 

264 msg = "%s is not equal %s != %s" 

265 logger.error( 

266 msg, 

267 k, 

268 self.get_attr_from_name(k), 

269 other.get_attr_from_name(k), 

270 ) 

271 raise ValueError( 

272 msg, 

273 k, 

274 self.get_attr_from_name(k), 

275 other.get_attr_from_name(k), 

276 ) 

277 for k, v in other.to_dict(single=True).items(): 

278 if hasattr(v, "size"): 

279 if v.size > 0: 

280 self.update_attribute(k, v) 

281 else: 

282 if v not in [None, 0.0, [], "", "1980-01-01T00:00:00+00:00"]: 

283 self.update_attribute(k, v) 

284 

285 ## Need this because channels are set when setting channels_recorded 

286 ## and it initiates an empty channel, but we need to fill it with 

287 ## the appropriate metadata. 

288 for ch in other.channels: 

289 self.add_channel(ch) 

290 

291 def has_channel(self, component: str) -> bool: 

292 """ 

293 Check to see if the channel already exists 

294 

295 :param component: channel component to look for 

296 :type component: string 

297 :return: True if found, False if not 

298 :rtype: boolean 

299 

300 """ 

301 

302 if component in self.channels_estimated: 

303 return True 

304 return False 

305 

306 def channel_index(self, component): 

307 """ 

308 get index of the channel in the channel list 

309 """ 

310 if self.has_channel(component): 

311 return self.channels_estimated.index(component) 

312 

313 def get_channel(self, component: str) -> FCChannel | None: 

314 """ 

315 Get a channel 

316 

317 :param component: channel component to look for 

318 :type component: string 

319 :return: FCChannel object based on channel type 

320 :rtype: :class:`mt_metadata.timeseries.Channel` 

321 

322 """ 

323 

324 if self.has_channel(component): 

325 return self.channels[component] 

326 

327 def add_channel(self, channel_obj: FCChannel) -> None: 

328 """ 

329 Add a channel to the list, check if one exists if it does overwrite it 

330 

331 :param channel_obj: channel object to add 

332 :type channel_obj: :class:`mt_metadata.transfer_functions.processing.fourier_coefficients.Channel` 

333 

334 """ 

335 if not isinstance(channel_obj, (FCChannel)): 

336 msg = f"Input must be metadata FCChannel not {type(channel_obj)}" 

337 logger.error(msg) 

338 raise ValueError(msg) 

339 

340 if self.has_channel(channel_obj.component): 

341 self.channels[channel_obj.component].update(channel_obj) 

342 logger.debug( 

343 f"ch {channel_obj.component} already exists, updating metadata" 

344 ) 

345 

346 else: 

347 self.channels.append(channel_obj) 

348 

349 self.update_time_period() 

350 

351 def remove_channel(self, channel_id: str) -> None: 

352 """ 

353 remove a channel from the survey 

354 

355 :param component: channel component to look for 

356 :type component: string 

357 

358 """ 

359 

360 if self.has_channel(channel_id): 

361 self.channels.remove(channel_id) 

362 self.channels_estimated.remove(channel_id) 

363 else: 

364 logger.warning(f"Could not find {channel_id} to remove.") 

365 

366 self.update_time_period() 

367 

368 @property 

369 def n_channels(self): 

370 return len(self.channels) 

371 

372 def update_time_period(self): 

373 """ 

374 update time period from ch information 

375 """ 

376 start = [] 

377 end = [] 

378 for ch in self.channels: 

379 if ch.time_period.start != "1980-01-01T00:00:00+00:00": 

380 start.append(ch.time_period.start) 

381 if ch.time_period.start != "1980-01-01T00:00:00+00:00": 

382 end.append(ch.time_period.end) 

383 if start: 

384 if self.time_period.start == "1980-01-01T00:00:00+00:00": 

385 self.time_period.start = min(start) 

386 else: 

387 if self.time_period.start > min(start): 

388 self.time_period.start = min(start) 

389 if end: 

390 if self.time_period.end == "1980-01-01T00:00:00+00:00": 

391 self.time_period.end = max(end) 

392 else: 

393 if self.time_period.end < max(end): 

394 self.time_period.end = max(end) 

395 

396 def is_valid_for_time_series_length(self, n_samples_ts: int) -> bool: 

397 """ 

398 Given a time series of len n_samples_ts, checks if there are sufficient samples to STFT. 

399 

400 """ 

401 required_num_samples = ( 

402 self.stft.window.num_samples 

403 + (self.stft.min_num_stft_windows - 1) 

404 * self.stft.window.num_samples_advance 

405 ) 

406 if n_samples_ts < required_num_samples: 

407 msg = ( 

408 f"{n_samples_ts} not enough samples for minimum of " 

409 f"{self.stft.min_num_stft_windows} stft windows of length " 

410 f"{self.stft.window.num_samples} and overlap {self.stft.window.overlap}" 

411 ) 

412 logger.warning(msg) 

413 return False 

414 else: 

415 return True 

416 

417 @property 

418 def fft_frequencies(self) -> np.ndarray: 

419 """Returns the one-sided fft frequencies (without Nyquist)""" 

420 return self.stft.window.fft_harmonics(self.decimation.sample_rate) 

421 

422 

423def fc_decimations_creator( 

424 initial_sample_rate: float, 

425 decimation_factors: Optional[list] = None, 

426 max_levels: Optional[int] = 6, 

427 time_period: Optional[TimePeriod] = None, 

428) -> List[Decimation]: 

429 """ 

430 

431 Creates mt_metadata FCDecimation objects that parameterize Fourier coefficient decimation levels. 

432 

433 Note 1: This does not yet work through the assignment of which bands to keep. Refer to 

434 mt_metadata.transfer_functions.processing.Processing.assign_bands() to see how this was done in the past 

435 

436 Parameters 

437 ---------- 

438 initial_sample_rate: float 

439 Sample rate of the "level0" data -- usually the sample rate during field acquisition. 

440 decimation_factors: Optional[list] 

441 The decimation factors that will be applied at each FC decimation level 

442 max_levels: Optional[int] 

443 The maximum number of decimation levels to allow 

444 time_period: Optional[TimePeriod] 

445 Provides the start and end times 

446 

447 Returns 

448 ------- 

449 fc_decimations: list 

450 Each element of the list is an object of type 

451 mt_metadata.transfer_functions.processing.fourier_coefficients.Decimation, 

452 (a.k.a. FCDecimation). 

453 

454 The order of the list corresponds the order of the cascading decimation 

455 - No decimation levels are omitted. 

456 - This could be changed in future by using a dict instead of a list, 

457 - e.g. decimation_factors = dict(zip(np.arange(max_levels), decimation_factors)) 

458 

459 """ 

460 if not decimation_factors: 

461 # msg = "No decimation factors given, set default values to EMTF default values [1, 4, 4, 4, ..., 4]") 

462 # logger.info(msg) 

463 default_decimation_factor = 4 

464 decimation_factors = max_levels * [default_decimation_factor] 

465 decimation_factors[0] = 1 

466 

467 # See Note 1 

468 fc_decimations = [] 

469 for i_dec_level, decimation_factor in enumerate(decimation_factors): 

470 fc_dec = Decimation() 

471 fc_dec.time_series_decimation.level = i_dec_level 

472 fc_dec.id = f"{i_dec_level}" 

473 fc_dec.time_series_decimation.factor = decimation_factor 

474 if i_dec_level == 0: 

475 current_sample_rate = 1.0 * initial_sample_rate 

476 else: 

477 current_sample_rate /= decimation_factor 

478 fc_dec.time_series_decimation.sample_rate = current_sample_rate 

479 

480 if time_period: 

481 if isinstance(time_period, TimePeriod): 

482 fc_dec.time_period = time_period 

483 else: 

484 msg = ( 

485 f"Not sure how to assign time_period with type {type(time_period)}" 

486 ) 

487 logger.info(msg) 

488 raise NotImplementedError(msg) 

489 

490 fc_decimations.append(fc_dec) 

491 

492 return fc_decimations 

493 

494 

495def get_degenerate_fc_decimation(sample_rate: float) -> list: 

496 """ 

497 WIP 

498 

499 Makes a default fc_decimation list. 

500 This "degenerate" config will only operate on the first decimation level. 

501 This is useful for testing. It could also be used in future on an MTH5 stored 

502 time series in decimation levels already as separate runs. 

503 

504 Parameters 

505 ---------- 

506 sample_rate: float 

507 The sample rate associated with the time-series to convert to spectrogram 

508 

509 Returns 

510 ------- 

511 output: list 

512 List has only one element which is of type FCDecimation, aka. 

513 

514 """ 

515 output = fc_decimations_creator( 

516 sample_rate, 

517 decimation_factors=[ 

518 1, 

519 ], 

520 max_levels=1, 

521 ) 

522 return output